From: Georgi Gerganov Date: Fri, 14 Jul 2023 08:03:55 +0000 (+0300) Subject: ggml : fix mul_mat src1 indexing when src1 is not contiguous (#386) X-Git-Tag: upstream/0.0.1642~1320 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=e55e3473688f7786b8a826014146f9d4d632bfa1;p=pkg%2Fggml%2Fsources%2Fggml ggml : fix mul_mat src1 indexing when src1 is not contiguous (#386) --- diff --git a/src/ggml.c b/src/ggml.c index c137ae65..256b8265 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -10684,6 +10684,8 @@ static void ggml_compute_forward_mul_mat( const enum ggml_type type = src0->type; + const bool src1_cont = ggml_is_contiguous(src1); + ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; @@ -10747,7 +10749,7 @@ static void ggml_compute_forward_mul_mat( float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); if (type != GGML_TYPE_F32) { - float * const wdata = params->wdata; + float * const wdata = params->wdata; ggml_to_float_t const to_float = type_traits[type].to_float; size_t id = 0; @@ -10805,7 +10807,7 @@ static void ggml_compute_forward_mul_mat( // src1 rows const int64_t nr1 = ne11*ne12*ne13; - void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type]; for (int64_t ir1 = 0; ir1 < nr1; ++ir1) { @@ -10828,7 +10830,15 @@ static void ggml_compute_forward_mul_mat( const int64_t i3 = i13; const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 ); - const char * src1_col = (const char *) wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size; + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont + ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size + : (i11*nb11 + i12*nb12 + i13*nb13)); float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));