]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp (#415)
authorGeorgi Gerganov <redacted>
Tue, 25 Jul 2023 15:28:22 +0000 (18:28 +0300)
committerGitHub <redacted>
Tue, 25 Jul 2023 15:28:22 +0000 (18:28 +0300)
- faster graph build
- inference speed-ups across GPU backends
- activation functions relax constraints

ggml-ci

include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-metal.h
src/ggml-metal.m
src/ggml-metal.metal
src/ggml.c
tests/test-grad0.c

index de44fba9e0961886ba2d35046e60dc7934f05961..c309f1361c6f6c5494369b19fb611754f46b7dda 100644 (file)
@@ -442,7 +442,7 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
-        char padding[8];
+        char padding[4];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -463,6 +463,11 @@ extern "C" {
         void * abort_callback_data;
     };
 
+    // next prime after GGML_MAX_NODES
+    // #define GGML_GRAPH_HASHTABLE_SIZE 4099
+    // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
+    #define GGML_GRAPH_HASHTABLE_SIZE 8273
+
     // computation graph
     struct ggml_cgraph {
         int n_nodes;
@@ -472,6 +477,8 @@ extern "C" {
         struct ggml_tensor * grads[GGML_MAX_NODES];
         struct ggml_tensor * leafs[GGML_MAX_NODES];
 
+        void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+
         // performance
         int     perf_runs;
         int64_t perf_cycles;
@@ -866,14 +873,17 @@ extern "C" {
 
     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     // a - x
     // b - dy
+    // TODO: update with configurable eps
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
index 0ab06ec9a4b6d6ebceab943763143af6b75b417a..d31fc79c10961de17704b317b7489f7a0c5e1c9e 100644 (file)
@@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    const float eps = 1e-6f;
-
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += WARP_SIZE) {
@@ -1073,10 +1071,12 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
+    uint16_t q16[8];
+    const uint8_t * q4 = (const uint8_t *)q16;
+
     for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const uint8_t * ql1 = x[i].qs + q_offset;
-        const uint8_t * ql2 = ql1 + 64;
         const uint8_t * qh  = x[i].qh + l0;
         const float   * y1  = yy + i*QK_K + y_offset;
         const float   * y2  = y1 + 128;
@@ -1092,15 +1092,25 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 
         float4 sum = {0.f, 0.f, 0.f, 0.f};
         float smin = 0;
+        const uint16_t * q1 = (const uint16_t *)ql1;
+        const uint16_t * q2 = q1 + 32;
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[8] & 0x0f0f;
+        q16[2] = (q1[0] >> 4) & 0x0f0f;
+        q16[3] = (q1[8] >> 4) & 0x0f0f;
+        q16[4] = q2[0] & 0x0f0f;
+        q16[5] = q2[8] & 0x0f0f;
+        q16[6] = (q2[0] >> 4) & 0x0f0f;
+        q16[7] = (q2[8] >> 4) & 0x0f0f;
         for (int l = 0; l < n; ++l) {
-            sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
-                   + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
-            sum.y += y1[l+32] * ((ql1[l+ 0] >>  4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
-                   + y1[l+48] * ((ql1[l+16] >>  4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
-            sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
-                   + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
-            sum.w += y2[l+32] * ((ql2[l+ 0] >>  4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
-                   + y2[l+48] * ((ql2[l+16] >>  4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
             smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
                   + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
         }
@@ -1554,15 +1564,25 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
 
-    const int bq8_offset = QR4_K * (iqs / QI8_1); // 0, 2, 4, 6
-
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
 
+#ifndef GGML_QKK_64
+
+    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
+
     const float    d = bq4_K->d;
     const float dmin = bq4_K->dmin;
 
-    const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);
+    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4));
+    const int v1 = q4[0];
+    const int v2 = q4[4];
 
     const uint16_t * scales = (const uint16_t *)bq4_K->scales;
     uint16_t aux[2];
@@ -1580,16 +1600,59 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     for (int i = 0; i < QR4_K; ++i) {
 
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
         const float d8i = bq8i->d;
+        const int * q8 = (const int *)bq8i->qs + (iqs%4);
+        const int ui1 = q8[0];
+        const int ui2 = q8[4];
 
-        const int vi = (v >> (4*i)) & 0x0F0F0F0F;
+        const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F;
+        const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F;
 
-        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc[i]); // SIMD dot product
-        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
+        const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+
+        sumf_d += d8i * (dot1 * sc[i]);
+        sumf_m += d8i * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
     }
 
     return d*sumf_d - dmin*sumf_m;
+
+#else
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    const uint16_t * a = (const uint16_t *)bq4_K->scales;
+    aux16[0] = a[0] & 0x0f0f;
+    aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+    const float dall = bq4_K->d[0];
+    const float dmin = bq4_K->d[1];
+
+    const float d8_1 = bq8_1[0].d;
+    const float d8_2 = bq8_1[1].d;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+    const int * q4 = (const int *)bq4_K->qs + iqs;
+    const int v1 = q4[0];
+    const int v2 = q4[4];
+
+    const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
+    const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
+    const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+    const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
+
+    sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
+    sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+
+    return dall * sumf_d - dmin * sumf_m;
+
+#endif
+
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1601,7 +1664,11 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_K * bq5_K = (const block_q5_K *) vbq;
 
-    const int bq8_offset = QR5_K * (iqs / QI8_1);
+#ifndef GGML_QKK_64
+
+    const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
+    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
+    const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
 
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
@@ -1609,31 +1676,87 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     const float    d = bq5_K->d;
     const float dmin = bq5_K->dmin;
 
-    const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]);
+    const int vl1 = ql[0];
+    const int vl2 = ql[4];
 
-    const int vh = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset;
+    const int vh1 = qh[0] >> bq8_offset;
+    const int vh2 = qh[4] >> bq8_offset;
 
-    for (int i = 0; i < QR5_K; ++i) {
-        const int isc = bq8_offset + i;
+    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
 
-        uint8_t sc, m;
-        get_scale_min_k4(isc, bq5_K->scales, sc, m);
+    for (int i = 0; i < QR5_K; ++i) {
 
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
         const float d8i = bq8i->d;
+        const int * q8 = (const int *)bq8i->qs + (iqs%4);
+        const int ui1 = q8[0];
+        const int ui2 = q8[4];
 
-        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+        const int vil1 = (vl1 >> (4*i)) & 0x0F0F0F0F;
+        const int vil2 = (vl2 >> (4*i)) & 0x0F0F0F0F;
+
+        const int vih1 = ((vh1 >> i) << 4) & 0x10101010;
+        const int vih2 = ((vh2 >> i) << 4) & 0x10101010;
+
+        const int vi1 = vil1 | vih1;
+        const int vi2 = vil2 | vih2;
 
-        const int vih = ((vh >> i) << 4) & 0x10101010;
+        const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
 
-        const int vi = vil | vih;
+        sumf_d += d8i * (dot1 * sc[i]);
+        sumf_m += d8i * (dot2 * m[i]);
 
-        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc); // SIMD dot product
-        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m);  // multiply constant part of q5_K with sum of q8_1 values
     }
 
     return d*sumf_d - dmin*sumf_m;
+
+#else
+
+    const int8_t * s = bq5_K->scales;
+
+    const float d = bq5_K->d;
+
+    const float d8_1 = bq8_1[0].d;
+    const float d8_2 = bq8_1[1].d;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+    const int * ql = (const int *)bq5_K->qs + iqs;
+    const int vl1 = ql[0];
+    const int vl2 = ql[4];
+
+    const int step = 4 * iqs; // 0, 4, 8, 12
+    const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3
+    const int in = step%8; // 0, 4, 0, 4
+    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+
+    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+
+    const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
+                       + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
+
+    return d * sumf_d;
+
+#endif
+
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -2074,10 +2197,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
 }
 
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@@ -2306,7 +2429,10 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, vec_dot_q4_K_q8_1>
+    // Note: we use QI4_K/2 instead of QI4_K to make the dot product template require 4 groups of quants to be processed per
+    //       kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales
+    //       is better amortized.
+    mul_mat_vec_q<QK_K, QI4_K/2, block_q4_K, vec_dot_q4_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2315,7 +2441,10 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, vec_dot_q5_K_q8_1>
+    // Note: we use QI5_K/2 instead of QI5_K to make the dot product template require 4 groups of quants to be processed per
+    //       kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales
+    //       is better amortized.
+    mul_mat_vec_q<QK_K, QI5_K/2, block_q5_K, vec_dot_q5_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2822,8 +2951,11 @@ inline void ggml_cuda_op_rms_norm(
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
     // compute
-    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
 
     (void) src1;
     (void) dst;
index 928f1705c381cf710468968b3f5a66e19e5a0c47..16f1a0caacfac483cf2c33e0715ca8d1c61ea7ce 100644 (file)
@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 // get data from the device into host memory
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+// if the graph has been optimized for concurrently dispatch
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
index 1fd6e857ffe6138dc2a5be702bebcd57eb5c6fc5..74a6bff40411784f2b13ca4c1a7bf607bfc400c4 100644 (file)
@@ -36,6 +36,9 @@ struct ggml_metal_context {
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
+    int concur_list[GGML_MAX_NODES];
+    int concur_list_len;
+
     // custom kernels
 #define GGML_METAL_DECL_KERNEL(name) \
     id<MTLFunction>             function_##name; \
@@ -98,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue  = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
+    ctx->concur_list_len = 0;
 
     // determine if we can use MPS
     if (MPSSupportsMTLDevice(ctx->device)) {
@@ -217,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
 }
 
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    if (ctx->concur_list_len) {
+        return true;
+    }
+    return false;
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -355,11 +366,98 @@ void ggml_metal_get_tensor(
     memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
 }
 
+void ggml_metal_graph_find_concurrency(
+        struct ggml_metal_context * ctx,
+        struct ggml_cgraph * gf) {
+    int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
+    int nodes_unused[GGML_MAX_NODES];
+
+    for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+    for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+    ctx->concur_list_len = 0;
+
+    int n_left = gf->n_nodes;
+    int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+    int level_pos = 0;  // at ctx->concur_list, the last layer (level) ends at level_pos
+
+    while (n_left > 0) {
+        // number of nodes at a layer (that can be issued concurrently)
+        int concurrency = 0;
+        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+            if (nodes_unused[i]) {
+                // if the requirements for gf->nodes[i] are satisfied
+                int exe_flag=1;
+                // scan all srcs
+                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+                    if (src_cur) {
+                        // if is leaf nodes it's satisfied.
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+                        // otherwise this src should be the output from previous nodes.
+                        int is_found = 0;
+                        // scan 2*search_depth back because we inserted barrier.
+                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                            if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+                        }
+                        if (is_found == 0) {exe_flag = 0; break;}
+                    }
+                }
+                if (exe_flag) {
+                    // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+                    // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+                    int64_t data_start = (int64_t) gf->nodes[i]->data;
+                    int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    for (int j = n_start; j < i; j++) {
+                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+                                            && gf->nodes[j]->op != GGML_OP_VIEW \
+                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                            if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+                                ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+                                continue;
+                            } else {
+                                exe_flag = 0;
+                            }
+                        }
+                    }
+                }
+                if (exe_flag) {
+                    ctx->concur_list[level_pos + concurrency] = i;
+                    nodes_unused[i] = 0;
+                    concurrency++;
+                    ctx->concur_list_len++;
+                }
+            }
+        }
+        n_left -= concurrency;
+        // adding a barrier different layer
+        ctx->concur_list[level_pos + concurrency] = -1;
+        ctx->concur_list_len++;
+        // jump all sorted nodes at nodes_bak
+        while (!nodes_unused[n_start]) {n_start++;}
+        level_pos += concurrency + 1;
+    }
+
+    if (ctx->concur_list_len > GGML_MAX_NODES) {
+        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+    }
+}
+
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
+    // if there is ctx->concur_list, dispatch concurrently
+    // else fallback to serial dispatch
+    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+    const int n_nodes  = has_concur ? ctx->concur_list_len      : gf->n_nodes;
+    edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
 
@@ -378,7 +476,7 @@ void ggml_metal_graph_compute(
     dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
         dispatch_async(queue, ^{
             size_t offs_src0 = 0;
@@ -389,10 +487,21 @@ void ggml_metal_graph_compute(
 
             id<MTLComputeCommandEncoder> encoder = nil;
 
-            const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start =                                  (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int ind = node_start; ind < node_end; ++ind) {
+                const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+                if (i == -1) {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                        continue;
+                    }
+                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    continue;
+                }
 
-            for (int i = node_start; i < node_end; ++i) {
                 metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -463,7 +572,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ADD:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             if (ggml_nelements(src1) == ne10) {
@@ -484,7 +593,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             if (ggml_nelements(src1) == ne10) {
@@ -505,7 +614,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_SCALE:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const float scale = *(const float *) src1->data;
@@ -524,7 +633,7 @@ void ggml_metal_graph_compute(
                             case GGML_UNARY_OP_SILU:
                                 {
                                     if (encoder == nil) {
-                                        encoder = [command_buffer computeCommandEncoder];
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                     }
 
                                     [encoder setComputePipelineState:ctx->pipeline_silu];
@@ -538,7 +647,7 @@ void ggml_metal_graph_compute(
                             case GGML_UNARY_OP_RELU:
                                 {
                                     if (encoder == nil) {
-                                        encoder = [command_buffer computeCommandEncoder];
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                     }
 
                                     [encoder setComputePipelineState:ctx->pipeline_relu];
@@ -552,7 +661,7 @@ void ggml_metal_graph_compute(
                             case GGML_UNARY_OP_GELU:
                                 {
                                     if (encoder == nil) {
-                                        encoder = [command_buffer computeCommandEncoder];
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                     }
 
                                     [encoder setComputePipelineState:ctx->pipeline_gelu];
@@ -572,7 +681,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_SOFT_MAX:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int nth = 32;
@@ -590,7 +699,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_DIAG_MASK_INF:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int n_past = ((int32_t *)(dst->op_params))[0];
@@ -653,7 +762,7 @@ void ggml_metal_graph_compute(
                                 }
                             } else {
                                 if (encoder == nil) {
-                                    encoder = [command_buffer computeCommandEncoder];
+                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                 }
 
                                 int nth0 = 32;
@@ -780,7 +889,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_GET_ROWS:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             switch (src0->type) {
@@ -809,10 +918,11 @@ void ggml_metal_graph_compute(
                     case GGML_OP_RMS_NORM:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
-                            const float eps = 1e-6f;
+                            float eps;
+                            memcpy(&eps, dst->op_params, sizeof(float));
 
                             const int nth = 512;
 
@@ -831,7 +941,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_NORM:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const float eps = 1e-5f;
@@ -853,7 +963,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ALIBI:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             GGML_ASSERT((src0t == GGML_TYPE_F32));
@@ -896,7 +1006,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ROPE:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int n_past = ((int32_t *) dst->op_params)[0];
@@ -940,7 +1050,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_CONT:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int nth = 32;
index 987376d560879b97af68f85b066b6e5e4fb5ad6d..696b33ce75cf4fa8d92850ef5e762c32190d6410 100644 (file)
@@ -387,87 +387,90 @@ kernel void kernel_rms_norm(
     }
 }
 
-// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    float2 acc = 0.f;
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+    return d * (sumy * -8.f + acc[0] + acc[1]);
 }
 
-// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+    float2 acc = 0.f;
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
+    return d * (acc[0] + acc[1]) + sumy * m;
 }
 
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-template<typename block_q_type>
+//Note: This is a template, but strictly speaking it only applies to
+//      quantizations where the block size is 32. It also does not
+//      giard against the number of rows not being divisible by
+//      N_DST, so this is another explicit assumption of the implementation.
+template<typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
                     int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
                     uint2 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
     device const float      * y = (device const float      *) src1 + r1*ne10;
-    float4 y_curr[8];       // src1 vector cache
-    float sumf[N_DST]={0.f}, all_sum;
-    thread float * yl=(thread float *)y_curr;
+    float yl[16];       // src1 vector cache
+    float sumf[nr]={0.f};
 
-    // each thread in a SIMD group deals with 1 block.
-    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4  *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
+    const int ix = tiisg/2;
+    const int il = 8*(tiisg%2);
 
-        for (int row = 0; row < N_DST; row++) {
-            sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
-        }
-    }
+    device const float * yb = y + ix * QK4_0 + il;
 
-    // from now loads two rows every time and 16 blocks per row
-    int ir = tiisg / (N_SIMDWIDTH / 2);
-    int ib = tiisg % (N_SIMDWIDTH / 2);
-    for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
-        int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
         float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+        for (int i = 0; i < 8; i += 2) {
+            sumy += yb[i] + yb[i+1];
+            yl[i+0] = yb[i+ 0];
+            yl[i+1] = yb[i+ 1]/256.f;
+            sumy += yb[i+16] + yb[i+17];
+            yl[i+8] = yb[i+16]/16.f;
+            yl[i+9] = yb[i+17]/4096.f;
         }
 
-        for (int row = 0; row < N_DST; row+=2) {
-            if (nb_start + ib < nb) {
-                sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
-            }
+        for (int row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
         }
+
+        yb += QK4_0 * 16;
     }
 
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-            dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + first_row + row] = tot;
         }
     }
 }
@@ -483,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_q4_1_f32(
@@ -497,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_f16_f32(
index 960b8057709a987e8aa83395fbe51a5c597fe21f..35c56151b8f7c7bbbe9ca5f4b7e4372a78cc173d 100644 (file)
@@ -4229,6 +4229,15 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
 bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -5781,6 +5790,7 @@ struct ggml_tensor * ggml_norm_inplace(
 static struct ggml_tensor * ggml_rms_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        float eps,
         bool inplace) {
     bool is_node = false;
 
@@ -5790,7 +5800,7 @@ static struct ggml_tensor * ggml_rms_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    // TODO: maybe store epsilon here?
+    ggml_set_op_params(result, &eps, sizeof(eps));
 
     result->op   = GGML_OP_RMS_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5801,14 +5811,16 @@ static struct ggml_tensor * ggml_rms_norm_impl(
 
 struct ggml_tensor * ggml_rms_norm(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_rms_norm_impl(ctx, a, false);
+        struct ggml_tensor  * a,
+        float  eps) {
+    return ggml_rms_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_rms_norm_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_rms_norm_impl(ctx, a, true);
+        struct ggml_tensor  * a,
+        float eps) {
+    return ggml_rms_norm_impl(ctx, a, eps, true);
 }
 
 struct ggml_tensor * ggml_rms_norm_back(
@@ -7018,14 +7030,16 @@ struct ggml_tensor * ggml_flash_attn(
     }
 
     //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, q->n_dims, q->ne);
+
+    int32_t t = masked ? 1 : 0;
+    ggml_set_op_params(result, &t, sizeof(t));
 
     result->op   = GGML_OP_FLASH_ATTN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = q;
     result->src[1] = k;
     result->src[2] = v;
-    result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);
 
     return result;
 }
@@ -7049,7 +7063,7 @@ struct ggml_tensor * ggml_flash_ff(
     }
 
     //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne);
 
     result->op   = GGML_OP_FLASH_FF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7115,13 +7129,15 @@ struct ggml_tensor * ggml_flash_attn_back(
 
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
+    int32_t masked_i = masked ? 1 : 0;
+    ggml_set_op_params(result, &masked_i, sizeof(masked_i));
+
     result->op   = GGML_OP_FLASH_ATTN_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = q;
     result->src[1] = k;
     result->src[2] = v;
     result->src[3] = d;
-    result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);
 
     return result;
 }
@@ -9811,8 +9827,8 @@ static void ggml_compute_forward_gelu_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -9870,8 +9886,8 @@ static void ggml_compute_forward_gelu_quick_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -9929,8 +9945,8 @@ static void ggml_compute_forward_silu_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -9989,9 +10005,9 @@ static void ggml_compute_forward_silu_back_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * grad,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(grad));
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src0, grad));
 
@@ -10131,7 +10147,8 @@ static void ggml_compute_forward_rms_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const float eps = 1e-6f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -14760,7 +14777,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
+                const int32_t t = ggml_get_op_params_i32(tensor, 0);
                 GGML_ASSERT(t == 0 || t == 1);
                 const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
@@ -14771,7 +14788,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
-                int32_t t = ggml_get_i32_1d(tensor->src[4], 0);
+                int32_t t = ggml_get_op_params_i32(tensor, 0);
                 GGML_ASSERT(t == 0 || t == 1);
                 bool masked = t != 0;
                 ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
@@ -15389,7 +15406,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 struct ggml_tensor * flash_grad = NULL;
                 if (src0->grad || src1->grad || tensor->src[2]->grad) {
-                    int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
+                    int32_t t = ggml_get_op_params_i32(tensor, 0);
                     GGML_ASSERT(t == 0 || t == 1);
                     bool masked = t != 0;
                     flash_grad =
@@ -15661,6 +15678,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
     }
 }
 
+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+
+static size_t hash(void * p) {
+    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+    size_t h = hash(p);
+
+    // linear probing
+    size_t i = h;
+    while (hash_table[i] != NULL && hash_table[i] != p) {
+        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+        if (i == h) {
+            // hash table is full
+            GGML_ASSERT(false);
+        }
+    }
+
+    if (hash_table[i] == p) {
+        return true;
+    }
+
+    // insert
+    hash_table[i] = p;
+    return false;
+}
+
 static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     if (node->grad == NULL) {
         // this usually happens when we generate intermediate nodes from constants in the backward pass
@@ -15671,16 +15716,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
     }
 
     // check if already visited
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (cgraph->nodes[i] == node) {
-            return;
-        }
-    }
-
-    for (int i = 0; i < cgraph->n_leafs; i++) {
-        if (cgraph->leafs[i] == node) {
-            return;
-        }
+    if (hash_insert(cgraph->visited_hash_table, node)) {
+        return;
     }
 
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
@@ -15743,6 +15780,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
         /*.nodes        =*/ { NULL },
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
+        /*.hash_table   =*/ { NULL },
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
@@ -15784,7 +15822,7 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
 
         if (node->is_param) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            ggml_build_forward_impl(&result, node->grad, true);
+            ggml_build_forward_expand(&result, node->grad);
         }
     }
 
index ef20bce516662e645395475e3dc6fdf01756b4ab..6d312216d58af7c89287696d00127966105bf7d7 100644 (file)
@@ -850,7 +850,7 @@ int main(int argc, const char ** argv) {
                     ggml_set_param(ctx0, x[i]);
                 }
 
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0]));
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
 
                 check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
             }