From: Justine Tunney Date: Mon, 22 Apr 2024 19:00:36 +0000 (-0400) Subject: llamafile : improve sgemm.cpp (llama/6796) X-Git-Tag: upstream/0.0.1642~752 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=0f2335ef1d88a83584364c56b4d6d5db28744506;p=pkg%2Fggml%2Fsources%2Fggml llamafile : improve sgemm.cpp (llama/6796) * llamafile : improve sgemm.cpp - Re-enable by default - Fix issue described in #6716 - Make code more abstract, elegant, and maintainable - Faster handling of weirdly shaped `m` an `n` edge cases * Address review comments * Help clang produce fma instructions * Address review comments --- diff --git a/src/ggml.c b/src/ggml.c index a745104c..b9e2150f 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -10887,7 +10887,7 @@ static void ggml_compute_forward_mul_mat( #endif #if GGML_USE_LLAMAFILE - if (nb10 == ggml_type_size(src1->type)) { + if (src1_cont) { for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), @@ -10940,15 +10940,13 @@ UseGgmlGemm1:; const size_t row_size = ggml_row_size(vec_dot_type, ne10); #if GGML_USE_LLAMAFILE - if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) { + if (src1->type != vec_dot_type) { for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - (const char *)wdata + ggml_row_size(vec_dot_type, - nb12/ggml_type_size(src1->type)*i12 + - nb13/ggml_type_size(src1->type)*i13), + (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type), (char *)dst->data + i12*nb2 + i13*nb3, nb1/ggml_type_size(dst->type),