]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : broadcast mul_mat + conv batch support (#2199)
authorGeorgi Gerganov <redacted>
Wed, 12 Jul 2023 17:51:29 +0000 (20:51 +0300)
committerGitHub <redacted>
Wed, 12 Jul 2023 17:51:29 +0000 (20:51 +0300)
* ggml : broadcast mul_mat + conv batch support

* ggml : apply mul_mat broadcast fix by @jploski

ggml.c

diff --git a/ggml.c b/ggml.c
index 3d10dd00d71f79215b59d8b7d2aeaafe6a335a2a..c137ae658df7f7bf97e74f85cdca55c0ab8cf7b5 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4168,10 +4168,9 @@ static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
 static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return
-        (t0->ne[0] == t1->ne[0])  &&
-        (t0->ne[2] == t1->ne[2])  &&
-        (t0->ne[3] == t1->ne[3]);
+    return (t0->ne[0]           == t1->ne[0])  &&
+           (t1->ne[2]%t0->ne[2] == 0)          && // verify t0 is broadcastable
+           (t1->ne[3]%t0->ne[3] == 0);
 }
 
 static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
@@ -6036,8 +6035,8 @@ struct ggml_tensor * ggml_mul_mat(
         is_node = true;
     }
 
-    const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+    const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
 
     result->op   = GGML_OP_MUL_MAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7173,7 +7172,6 @@ struct ggml_tensor* ggml_conv_2d(
     int                  d0,
     int                  d1) {
 
-    GGML_ASSERT(b->ne[3] == 1);
     GGML_ASSERT(a->ne[2] == b->ne[2]);
     bool is_node = false;
 
@@ -7185,7 +7183,7 @@ struct ggml_tensor* ggml_conv_2d(
     const int64_t ne[4] = {
         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
         ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
-        a->ne[3], 1,
+        a->ne[3], b->ne[3],
     };
     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
@@ -10641,7 +10639,6 @@ static void ggml_compute_forward_rms_norm_back(
     }
 }
 
-
 // ggml_compute_forward_mul_mat
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -10685,17 +10682,17 @@ static void ggml_compute_forward_mul_mat(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
-    GGML_ASSERT(ne3  == ne13);
-
     const enum ggml_type type = src0->type;
 
     ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
     GGML_ASSERT(nb10 == sizeof(float));
@@ -10706,16 +10703,16 @@ static void ggml_compute_forward_mul_mat(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
 #if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        //       ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
@@ -10725,6 +10722,11 @@ static void ggml_compute_forward_mul_mat(
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+        //       ref: https://github.com/ggerganov/ggml/pull/224
+        GGML_ASSERT(ne02 == ne12);
+        GGML_ASSERT(ne03 == ne13);
+
         if (params->ith != 0) {
             return;
         }
@@ -10794,41 +10796,44 @@ static void ggml_compute_forward_mul_mat(
         return;
     }
 
-    // parallelize by src0 rows using ggml_vec_dot_q
+    // parallelize by src0 rows
+    const int64_t dr = (ne01 + nth - 1)/nth;
 
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
+    const int64_t ir10 = dr*ith;
+    const int64_t ir11 = MIN(ir10 + dr, ne01);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    // src1 rows
+    const int64_t nr1 = ne11*ne12*ne13;
 
     void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-        const int i13 = i03;
-        const int i12 = i02;
-
-        const int i0 = i01;
-        const int i2 = i02;
-        const int i3 = i03;
-
-        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*row_size));
-
-        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
-
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
+    const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+
+    for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
+        const int64_t i13 = (ir1/(ne12*ne11));
+        const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
+        const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
+
+        const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
+        const int64_t i03 = (ir0/(ne02));
+        // Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2.
+        // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470:
+        // GG: this is likely the correct way to broadcast, though need some more thought
+        //     therefore leaving the comments to remind us for now
+        const int64_t i02 = (i12 / (ne12 / ne02));
+        // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon)
+        // const int64_t i02 = (ir0 - i03*ne02);
+
+        const int64_t i1 = i11;
+        const int64_t i2 = i12;
+        const int64_t i3 = i13;
+
+        const char * src0_row = (const char *) src0->data + (  0 + i02*nb02 + i03*nb03     );
+        const char * src1_col = (const char *)      wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
+
+        float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
+
+        for (int64_t ir = ir10; ir < ir11; ++ir) {
+            vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
         }
     }
 
@@ -13013,16 +13018,18 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-            for (int i12 = 0; i12 < ne12; i12++) {
-                const float * const src = (float *)((char *) src1->data + i12*nb12);
-                ggml_fp16_t * dst_data = wdata;
-
-                for (int i1 = 0; i1 < ne1; i1++) {
-                    for (int i0 = 0; i0 < ne0; i0++) {
-                        for (int ik1 = 0; ik1 < nk1; ik1++) {
-                            for (int ik0 = 0; ik0 < nk0; ik0++) {
-                                dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
-                                    GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
+            for (int i13 = 0; i13 < ne13; i13++) {
+                for (int i12 = 0; i12 < ne12; i12++) {
+                    const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
+                    ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
+
+                    for (int i1 = 0; i1 < ne1; i1++) {
+                        for (int i0 = 0; i0 < ne0; i0++) {
+                            for (int ik1 = 0; ik1 < nk1; ik1++) {
+                                for (int ik0 = 0; ik0 < nk0; ik0++) {
+                                    dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
+                                        GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
+                                }
                             }
                         }
                     }
@@ -13049,14 +13056,16 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
 
     ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-    for (int i2 = ip0; i2 < ip1; i2++) {
-        float * dst_data = (float *)((char *) dst->data + i2*nb2);
-
-        for (int i1 = 0; i1 < ne1; ++i1) {
-            for (int i0 = 0; i0 < ne0; ++i0) {
-                ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
-                        (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
-                        (ggml_fp16_t *)                wdata + (i1*ne0 + i0)*ew0);
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ip0; i2 < ip1; i2++) {
+            float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2);
+
+            for (int i1 = 0; i1 < ne1; ++i1) {
+                for (int i0 = 0; i0 < ne0; ++i0) {
+                    ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
+                            (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
+                            (ggml_fp16_t *)                wdata + i3*nb3 + (i1*ne0 + i0)*ew0);
+                }
             }
         }
     }
@@ -13105,10 +13114,9 @@ static void ggml_compute_forward_conv_2d(
 
     if (s0 == src0->ne[0] && s1 == src0->ne[1]) {
         ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst);
-    }
-    else {
+    } else {
         GGML_ASSERT(false); // only stride equal to kernel size is supported
-    };
+    }
 }
 
 // ggml_compute_forward_pool_1d_sk_p0
@@ -16558,8 +16566,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 {
                     n_tasks = n_threads;
 
-                    GGML_ASSERT(node->src[1]->ne[3] == 1);
-
                     const int64_t ne00 = node->src[0]->ne[0]; // W
                     const int64_t ne01 = node->src[0]->ne[1]; // H
                     const int64_t ne02 = node->src[0]->ne[2]; // C