]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : fix mul_mat src1 indexing when src1 is not contiguous (#386)
authorGeorgi Gerganov <redacted>
Fri, 14 Jul 2023 08:03:55 +0000 (11:03 +0300)
committerGitHub <redacted>
Fri, 14 Jul 2023 08:03:55 +0000 (11:03 +0300)
src/ggml.c

index c137ae658df7f7bf97e74f85cdca55c0ab8cf7b5..256b826556fddc9f8574edc9a22029300615fda3 100644 (file)
@@ -10684,6 +10684,8 @@ static void ggml_compute_forward_mul_mat(
 
     const enum ggml_type type = src0->type;
 
+    const bool src1_cont = ggml_is_contiguous(src1);
+
     ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
@@ -10747,7 +10749,7 @@ static void ggml_compute_forward_mul_mat(
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 if (type != GGML_TYPE_F32) {
-                    float * const wdata = params->wdata;
+                            float * const wdata    = params->wdata;
                     ggml_to_float_t const to_float = type_traits[type].to_float;
 
                     size_t id = 0;
@@ -10805,7 +10807,7 @@ static void ggml_compute_forward_mul_mat(
     // src1 rows
     const int64_t nr1 = ne11*ne12*ne13;
 
-    void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+    const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
     const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
     for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
@@ -10828,7 +10830,15 @@ static void ggml_compute_forward_mul_mat(
         const int64_t i3 = i13;
 
         const char * src0_row = (const char *) src0->data + (  0 + i02*nb02 + i03*nb03     );
-        const char * src1_col = (const char *)      wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
+
+        // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+        //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+        //       the original src1 data pointer, so we should index using the indices directly
+        // TODO: this is a bit of a hack, we should probably have a better way to handle this
+        const char * src1_col = (const char *) wdata +
+            (src1_cont
+             ? (i11      + i12*ne11 + i13*ne12*ne11)*row_size
+             : (i11*nb11 + i12*nb12 + i13*nb13));
 
         float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));