From: Georgi Gerganov Date: Tue, 25 Jul 2023 15:28:22 +0000 (+0300) Subject: ggml : sync llama.cpp (#415) X-Git-Tag: upstream/0.0.1642~1290 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=73536d9d35637ba87f9de4f383cc73872c925c44;p=pkg%2Fggml%2Fsources%2Fggml ggml : sync llama.cpp (#415) - faster graph build - inference speed-ups across GPU backends - activation functions relax constraints ggml-ci --- diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index de44fba9..c309f136 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -442,7 +442,7 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - char padding[8]; + char padding[4]; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); @@ -463,6 +463,11 @@ extern "C" { void * abort_callback_data; }; + // next prime after GGML_MAX_NODES + // #define GGML_GRAPH_HASHTABLE_SIZE 4099 + // next prime after GGML_MAX_NODES * 2 (nodes + leafs) + #define GGML_GRAPH_HASHTABLE_SIZE 8273 + // computation graph struct ggml_cgraph { int n_nodes; @@ -472,6 +477,8 @@ extern "C" { struct ggml_tensor * grads[GGML_MAX_NODES]; struct ggml_tensor * leafs[GGML_MAX_NODES]; + void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE]; + // performance int perf_runs; int64_t perf_cycles; @@ -866,14 +873,17 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rms_norm( struct ggml_context * ctx, - struct ggml_tensor * a); + struct ggml_tensor * a, + float eps); GGML_API struct ggml_tensor * ggml_rms_norm_inplace( struct ggml_context * ctx, - struct ggml_tensor * a); + struct ggml_tensor * a, + float eps); // a - x // b - dy + // TODO: update with configurable eps GGML_API struct ggml_tensor * ggml_rms_norm_back( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index 0ab06ec9..d31fc79c 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) { } } -static 
__global__ void rms_norm_f32(const float * x, float * dst, const int ncols) { +static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - const float eps = 1e-6f; - float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += WARP_SIZE) { @@ -1073,10 +1071,12 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; + uint16_t q16[8]; + const uint8_t * q4 = (const uint8_t *)q16; + for (int i = ix; i < num_blocks_per_row; i += 2) { const uint8_t * ql1 = x[i].qs + q_offset; - const uint8_t * ql2 = ql1 + 64; const uint8_t * qh = x[i].qh + l0; const float * y1 = yy + i*QK_K + y_offset; const float * y2 = y1 + 128; @@ -1092,15 +1092,25 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, float4 sum = {0.f, 0.f, 0.f, 0.f}; float smin = 0; + const uint16_t * q1 = (const uint16_t *)ql1; + const uint16_t * q2 = q1 + 32; + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[8] & 0x0f0f; + q16[2] = (q1[0] >> 4) & 0x0f0f; + q16[3] = (q1[8] >> 4) & 0x0f0f; + q16[4] = q2[0] & 0x0f0f; + q16[5] = q2[8] & 0x0f0f; + q16[6] = (q2[0] >> 4) & 0x0f0f; + q16[7] = (q2[8] >> 4) & 0x0f0f; for (int l = 0; l < n; ++l) { - sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) - + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0)); - sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) - + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0)); - sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) - + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0)); - sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) - + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 
16 : 0)); + sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) + + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0)); + sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) + + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0)); + sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) + + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0)); + sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) + + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0)); smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; } @@ -1554,15 +1564,25 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q4_K * bq4_K = (const block_q4_K *) vbq; - const int bq8_offset = QR4_K * (iqs / QI8_1); // 0, 2, 4, 6 - float sumf_d = 0.0f; float sumf_m = 0.0f; +#ifndef GGML_QKK_64 + + // iqs is in 0...15. 
bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * (iqs / (QI8_1/2)); + const float d = bq4_K->d; const float dmin = bq4_K->dmin; - const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]); + // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 + // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 + // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 + // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + + const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4)); + const int v1 = q4[0]; + const int v2 = q4[4]; const uint16_t * scales = (const uint16_t *)bq4_K->scales; uint16_t aux[2]; @@ -1580,16 +1600,59 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( for (int i = 0; i < QR4_K; ++i) { const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); const float d8i = bq8i->d; + const int * q8 = (const int *)bq8i->qs + (iqs%4); + const int ui1 = q8[0]; + const int ui2 = q8[4]; - const int vi = (v >> (4*i)) & 0x0F0F0F0F; + const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F; + const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F; - sumf_d += d8i * (__dp4a(vi, ui, 0) * sc[i]); // SIMD dot product - sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m[i]); // multiply constant part of q4_K with sum of q8_1 values + const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); + + sumf_d += d8i * (dot1 * sc[i]); + sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values } return d*sumf_d - dmin*sumf_m; + +#else + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + const uint16_t * a = (const uint16_t *)bq4_K->scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const float dall = bq4_K->d[0]; + const float dmin = bq4_K->d[1]; + + const float d8_1 = bq8_1[0].d; + 
const float d8_2 = bq8_1[1].d; + + const int ui1 = *((const int *)bq8_1[0].qs + iqs); + const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4); + const int ui3 = *((const int *)bq8_1[1].qs + iqs); + const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4); + + const int * q4 = (const int *)bq4_K->qs + iqs; + const int v1 = q4[0]; + const int v2 = q4[4]; + + const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); + const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); + + sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); + sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + + return dall * sumf_d - dmin * sumf_m; + +#endif + #else return 0.0f; // only to satisfy the compiler #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1601,7 +1664,11 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q5_K * bq5_K = (const block_q5_K *) vbq; - const int bq8_offset = QR5_K * (iqs / QI8_1); +#ifndef GGML_QKK_64 + + const int bq8_offset = QR5_K * (iqs / (QI8_1/2)); + const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4)); + const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4)); float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -1609,31 +1676,87 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( const float d = bq5_K->d; const float dmin = bq5_K->dmin; - const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]); + const int vl1 = ql[0]; + const int vl2 = ql[4]; - const int vh = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset; + const int vh1 = qh[0] >> bq8_offset; + const int vh2 = qh[4] >> bq8_offset; - for (int i = 0; i < QR5_K; ++i) { - const int isc = bq8_offset + i; + const uint16_t * scales = (const uint16_t 
*)bq5_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; - uint8_t sc, m; - get_scale_min_k4(isc, bq5_K->scales, sc, m); + for (int i = 0; i < QR5_K; ++i) { const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); const float d8i = bq8i->d; + const int * q8 = (const int *)bq8i->qs + (iqs%4); + const int ui1 = q8[0]; + const int ui2 = q8[4]; - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + const int vil1 = (vl1 >> (4*i)) & 0x0F0F0F0F; + const int vil2 = (vl2 >> (4*i)) & 0x0F0F0F0F; + + const int vih1 = ((vh1 >> i) << 4) & 0x10101010; + const int vih2 = ((vh2 >> i) << 4) & 0x10101010; + + const int vi1 = vil1 | vih1; + const int vi2 = vil2 | vih2; - const int vih = ((vh >> i) << 4) & 0x10101010; + const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); - const int vi = vil | vih; + sumf_d += d8i * (dot1 * sc[i]); + sumf_m += d8i * (dot2 * m[i]); - sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product - sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q5_K with sum of q8_1 values } return d*sumf_d - dmin*sumf_m; + +#else + + const int8_t * s = bq5_K->scales; + + const float d = bq5_K->d; + + const float d8_1 = bq8_1[0].d; + const float d8_2 = bq8_1[1].d; + + const int ui1 = *((const int *)bq8_1[0].qs + iqs); + const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4); + const int ui3 = *((const int *)bq8_1[1].qs + iqs); + const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4); + + const int * ql = (const int *)bq5_K->qs + iqs; + const int vl1 = ql[0]; + const int vl2 
= ql[4]; + + const int step = 4 * iqs; // 0, 4, 8, 12 + const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3 + const int in = step%8; // 0, 4, 0, 4 + const int vh = (*((const int *)(bq5_K->qh + in))) >> im; + + const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f); + const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f); + const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); + const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); + + const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1]) + + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]); + + return d * sumf_d; + +#endif + #else return 0.0f; // only to satisfy the compiler #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -2074,10 +2197,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i norm_f32<<>>(x, dst, ncols); } -static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols); + rms_norm_f32<<>>(x, dst, ncols, eps); } static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) { @@ -2306,7 +2429,10 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + // Note: we use QI4_K/2 instead of QI4_K to make the dot product template require 4 groups of quants to be processed per + // kernel call instead of 2. 
This results in a better performance because the cost of computing the k-quant scales + // is better amortized. + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2315,7 +2441,10 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + // Note: we use QI5_K/2 instead of QI5_K to make the dot product template require 4 groups of quants to be processed per + // kernel call instead of 2. This results in a better performance because the cost of computing the k-quant scales + // is better amortized. + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2822,8 +2951,11 @@ inline void ggml_cuda_op_rms_norm( const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + // compute - rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); + rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main); (void) src1; (void) dst; diff --git a/src/ggml-metal.h b/src/ggml-metal.h index 928f1705..16f1a0ca 100644 --- a/src/ggml-metal.h +++ b/src/ggml-metal.h @@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * // get data from the device into host memory void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t); +// try to find operations that can be run concurrently in the graph +// you should run it again if the topology of your graph changes +void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); + +// if the graph has been optimized for concurrent dispatch +bool ggml_metal_if_optimized(struct ggml_metal_context * ctx); + // same as ggml_graph_compute but uses Metal // creates gf->n_threads command buffers in parallel void
ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); diff --git a/src/ggml-metal.m b/src/ggml-metal.m index 1fd6e857..74a6bff4 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -36,6 +36,9 @@ struct ggml_metal_context { int n_buffers; struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; + int concur_list[GGML_MAX_NODES]; + int concur_list_len; + // custom kernels #define GGML_METAL_DECL_KERNEL(name) \ id function_##name; \ @@ -98,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { ctx->device = MTLCreateSystemDefaultDevice(); ctx->queue = [ctx->device newCommandQueue]; ctx->n_buffers = 0; + ctx->concur_list_len = 0; // determine if we can use MPS if (MPSSupportsMTLDevice(ctx->device)) { @@ -217,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) { ctx->n_cb = n_cb; } +bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) { + if (ctx->concur_list_len) { + return true; + } + return false; +} + // finds the Metal buffer that contains the tensor data on the GPU device // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the // Metal buffer based on the host memory pointer @@ -355,11 +366,98 @@ void ggml_metal_get_tensor( memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t)); } +void ggml_metal_graph_find_concurrency( + struct ggml_metal_context * ctx, + struct ggml_cgraph * gf) { + int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time + int nodes_unused[GGML_MAX_NODES]; + + for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;} + for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;} + ctx->concur_list_len = 0; + + int n_left = gf->n_nodes; + int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list + int level_pos = 0; // at ctx->concur_list, the last layer (level) 
ends at level_pos + + while (n_left > 0) { + // number of nodes at a layer (that can be issued concurrently) + int concurrency = 0; + for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) { + if (nodes_unused[i]) { + // if the requirements for gf->nodes[i] are satisfied + int exe_flag=1; + // scan all srcs + for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) { + struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind]; + if (src_cur) { + // if is leaf nodes it's satisfied. + if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;} + + // otherwise this src should be the output from previous nodes. + int is_found = 0; + // scan 2*search_depth back because we inserted barrier. + for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { + if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;} + } + if (is_found == 0) {exe_flag = 0; break;} + } + } + if (exe_flag) { + // check if nodes[i]'s data will be overwritten by a node before nodes[i]. 
+ // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3] + int64_t data_start = (int64_t) gf->nodes[i]->data; + int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]); + for (int j = n_start; j < i; j++) { + if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \ + && gf->nodes[j]->op != GGML_OP_VIEW \ + && gf->nodes[j]->op != GGML_OP_TRANSPOSE \ + && gf->nodes[j]->op != GGML_OP_PERMUTE) { + if (((int64_t)gf->nodes[j]->data) >= data_start + length || \ + ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) { + continue; + } else { + exe_flag = 0; + } + } + } + } + if (exe_flag) { + ctx->concur_list[level_pos + concurrency] = i; + nodes_unused[i] = 0; + concurrency++; + ctx->concur_list_len++; + } + } + } + n_left -= concurrency; + // adding a barrier different layer + ctx->concur_list[level_pos + concurrency] = -1; + ctx->concur_list_len++; + // jump all sorted nodes at nodes_bak + while (!nodes_unused[n_start]) {n_start++;} + level_pos += concurrency + 1; + } + + if (ctx->concur_list_len > GGML_MAX_NODES) { + fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__); + } +} + void ggml_metal_graph_compute( struct ggml_metal_context * ctx, struct ggml_cgraph * gf) { metal_printf("%s: evaluating graph\n", __func__); + // if there is ctx->concur_list, dispatch concurrently + // else fallback to serial dispatch + MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; + + const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES; + + const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; + edesc.dispatchType = has_concur ? 
MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; + // create multiple command buffers and enqueue them // then, we encode the graph into the command buffers in parallel @@ -378,7 +476,7 @@ void ggml_metal_graph_compute( dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb; + const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; dispatch_async(queue, ^{ size_t offs_src0 = 0; @@ -389,10 +487,21 @@ void ggml_metal_graph_compute( id encoder = nil; - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb; + const int node_start = (cb_idx + 0) * n_nodes_per_cb; + const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb; + + for (int ind = node_start; ind < node_end; ++ind) { + const int i = has_concur ? ctx->concur_list[ind] : ind; + + if (i == -1) { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; + continue; + } + [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; + continue; + } - for (int i = node_start; i < node_end; ++i) { metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); struct ggml_tensor * src0 = gf->nodes[i]->src[0]; @@ -463,7 +572,7 @@ void ggml_metal_graph_compute( case GGML_OP_ADD: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } if (ggml_nelements(src1) == ne10) { @@ -484,7 +593,7 @@ void ggml_metal_graph_compute( case GGML_OP_MUL: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } if (ggml_nelements(src1) == ne10) { @@ -505,7 +614,7 @@ void ggml_metal_graph_compute( case GGML_OP_SCALE: { if (encoder 
== nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const float scale = *(const float *) src1->data; @@ -524,7 +633,7 @@ void ggml_metal_graph_compute( case GGML_UNARY_OP_SILU: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } [encoder setComputePipelineState:ctx->pipeline_silu]; @@ -538,7 +647,7 @@ void ggml_metal_graph_compute( case GGML_UNARY_OP_RELU: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } [encoder setComputePipelineState:ctx->pipeline_relu]; @@ -552,7 +661,7 @@ void ggml_metal_graph_compute( case GGML_UNARY_OP_GELU: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } [encoder setComputePipelineState:ctx->pipeline_gelu]; @@ -572,7 +681,7 @@ void ggml_metal_graph_compute( case GGML_OP_SOFT_MAX: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const int nth = 32; @@ -590,7 +699,7 @@ void ggml_metal_graph_compute( case GGML_OP_DIAG_MASK_INF: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const int n_past = ((int32_t *)(dst->op_params))[0]; @@ -653,7 +762,7 @@ void ggml_metal_graph_compute( } } else { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } int nth0 = 32; @@ -780,7 +889,7 @@ void ggml_metal_graph_compute( case GGML_OP_GET_ROWS: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: 
edesc]; } switch (src0->type) { @@ -809,10 +918,11 @@ void ggml_metal_graph_compute( case GGML_OP_RMS_NORM: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } - const float eps = 1e-6f; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); const int nth = 512; @@ -831,7 +941,7 @@ void ggml_metal_graph_compute( case GGML_OP_NORM: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const float eps = 1e-5f; @@ -853,7 +963,7 @@ void ggml_metal_graph_compute( case GGML_OP_ALIBI: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } GGML_ASSERT((src0t == GGML_TYPE_F32)); @@ -896,7 +1006,7 @@ void ggml_metal_graph_compute( case GGML_OP_ROPE: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const int n_past = ((int32_t *) dst->op_params)[0]; @@ -940,7 +1050,7 @@ void ggml_metal_graph_compute( case GGML_OP_CONT: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const int nth = 32; diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal index 987376d5..696b33ce 100644 --- a/src/ggml-metal.metal +++ b/src/ggml-metal.metal @@ -387,87 +387,90 @@ kernel void kernel_rms_norm( } } -// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i]) -float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) { +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with 
the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; - float4 acc = 0.f; - device uint16_t * qs = ((device uint16_t *)qb_curr + 1); - for (int i = 0; i < 16; i+=2) { - acc[0] += yl[i] * (qs[i / 2] & 0x000F); - acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); - acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); + return d * (sumy * -8.f + acc[0] + acc[1]); } -// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i]) -float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) { +// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; - float4 acc = 0.f; - device uint16_t * qs = ((device uint16_t *)qb_curr + 2); - for (int i = 0; i < 16; i+=2) { - acc[0] += yl[i] * (qs[i / 2] & 0x000F); - acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); - acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float2 acc 
= 0.f; + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; + return d * (acc[0] + acc[1]) + sumy * m; } // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 -template +//Note: This is a template, but strictly speaking it only applies to +// quantizations where the block size is 32. It also does not +// guard against the number of rows not being divisible by +// N_DST, so this is another explicit assumption of the implementation. +template void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, uint2 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; - device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; + const int first_row = (r0 * nsg + sgitg) * nr; + device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb; device const float * y = (device const float *) src1 + r1*ne10; - float4 y_curr[8]; // src1 vector cache - float sumf[N_DST]={0.f}, all_sum; - thread float * yl=(thread float *)y_curr; + float yl[16]; // src1 vector cache + float sumf[nr]={0.f}; - // each thread in a SIMD group deals with 1 block.
- for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } + const int ix = tiisg/2; + const int il = 8*(tiisg%2); - for (int row = 0; row < N_DST; row++) { - sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl); - } - } + device const float * yb = y + ix * QK4_0 + il; - // from now loads two rows every time and 16 blocks per row - int ir = tiisg / (N_SIMDWIDTH / 2); - int ib = tiisg % (N_SIMDWIDTH / 2); - for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) { - int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; } - for (int row = 0; row < N_DST; row+=2) { - if (nb_start + ib < nb) { - sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl); - } + for (int row = 0; row < nr; row++) { + sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); } + + yb += QK4_0 * 16; } - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[r1*ne0 + 
first_row + row] = tot; } } } @@ -483,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_q4_1_f32( @@ -497,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_f16_f32( diff --git a/src/ggml.c b/src/ggml.c index 960b8057..35c56151 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -4229,6 +4229,15 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) { tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } +static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + bool ggml_is_permuted(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); @@ -5781,6 +5790,7 @@ struct ggml_tensor * ggml_norm_inplace( static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, + float eps, bool inplace) { bool is_node = false; @@ -5790,7 +5800,7 @@ static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - // TODO: maybe store epsilon here? 
+ ggml_set_op_params(result, &eps, sizeof(eps)); result->op = GGML_OP_RMS_NORM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -5801,14 +5811,16 @@ static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_tensor * ggml_rms_norm( struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_rms_norm_impl(ctx, a, false); + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, false); } struct ggml_tensor * ggml_rms_norm_inplace( struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_rms_norm_impl(ctx, a, true); + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, true); } struct ggml_tensor * ggml_rms_norm_back( @@ -7018,14 +7030,16 @@ struct ggml_tensor * ggml_flash_attn( } //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, q->n_dims, q->ne); + + int32_t t = masked ? 1 : 0; + ggml_set_op_params(result, &t, sizeof(t)); result->op = GGML_OP_FLASH_ATTN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = q; result->src[1] = k; result->src[2] = v; - result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0); return result; } @@ -7049,7 +7063,7 @@ struct ggml_tensor * ggml_flash_ff( } //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne); result->op = GGML_OP_FLASH_FF; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7115,13 +7129,15 @@ struct ggml_tensor * ggml_flash_attn_back( struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + int32_t masked_i = masked ? 1 : 0; + ggml_set_op_params(result, &masked_i, sizeof(masked_i)); + result->op = GGML_OP_FLASH_ATTN_BACK; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; result->src[0] = q; result->src[1] = k; result->src[2] = v; result->src[3] = d; - result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0); return result; } @@ -9811,8 +9827,8 @@ static void ggml_compute_forward_gelu_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -9870,8 +9886,8 @@ static void ggml_compute_forward_gelu_quick_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -9929,8 +9945,8 @@ static void ggml_compute_forward_silu_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -9989,9 +10005,9 @@ static void ggml_compute_forward_silu_back_f32( const struct ggml_tensor * src0, const struct ggml_tensor * grad, struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(grad)); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad)); + 
GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0, grad)); @@ -10131,7 +10147,8 @@ static void ggml_compute_forward_rms_norm_f32( GGML_TENSOR_UNARY_OP_LOCALS; - const float eps = 1e-6f; // TODO: make this a parameter + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -14760,7 +14777,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_FLASH_ATTN: { - const int32_t t = ggml_get_i32_1d(tensor->src[3], 0); + const int32_t t = ggml_get_op_params_i32(tensor, 0); GGML_ASSERT(t == 0 || t == 1); const bool masked = t != 0; ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); @@ -14771,7 +14788,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_FLASH_ATTN_BACK: { - int32_t t = ggml_get_i32_1d(tensor->src[4], 0); + int32_t t = ggml_get_op_params_i32(tensor, 0); GGML_ASSERT(t == 0 || t == 1); bool masked = t != 0; ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor); @@ -15389,7 +15406,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { - int32_t t = ggml_get_i32_1d(tensor->src[3], 0); + int32_t t = ggml_get_op_params_i32(tensor, 0); GGML_ASSERT(t == 0 || t == 1); bool masked = t != 0; flash_grad = @@ -15661,6 +15678,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } } +static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small"); + +static size_t hash(void * p) { + return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE; +} + +static bool 
hash_insert(void * hash_table[], void * p) { + size_t h = hash(p); + + // linear probing + size_t i = h; + while (hash_table[i] != NULL && hash_table[i] != p) { + i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE; + if (i == h) { + // hash table is full + GGML_ASSERT(false); + } + } + + if (hash_table[i] == p) { + return true; + } + + // insert + hash_table[i] = p; + return false; +} + static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { if (node->grad == NULL) { // this usually happens when we generate intermediate nodes from constants in the backward pass @@ -15671,16 +15716,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * } // check if already visited - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i] == node) { - return; - } - } - - for (int i = 0; i < cgraph->n_leafs; i++) { - if (cgraph->leafs[i] == node) { - return; - } + if (hash_insert(cgraph->visited_hash_table, node)) { + return; } for (int i = 0; i < GGML_MAX_SRC; ++i) { @@ -15743,6 +15780,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { /*.nodes =*/ { NULL }, /*.grads =*/ { NULL }, /*.leafs =*/ { NULL }, + /*.hash_table =*/ { NULL }, /*.perf_runs =*/ 0, /*.perf_cycles =*/ 0, /*.perf_time_us =*/ 0, @@ -15784,7 +15822,7 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg if (node->is_param) { GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - ggml_build_forward_impl(&result, node->grad, true); + ggml_build_forward_expand(&result, node->grad); } } diff --git a/tests/test-grad0.c b/tests/test-grad0.c index ef20bce5..6d312216 100644 --- a/tests/test-grad0.c +++ b/tests/test-grad0.c @@ -850,7 +850,7 @@ int main(int argc, const char ** argv) { ggml_set_param(ctx0, x[i]); } - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0])); + struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f)); check_gradient("rms_norm", ctx0, x, 
f, ndims, nargs, 1e-4f, 1.0f, INFINITY); }