]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
llamafile : improve sgemm.cpp (llama/6796)
authorJustine Tunney <redacted>
Mon, 22 Apr 2024 19:00:36 +0000 (15:00 -0400)
committerGeorgi Gerganov <redacted>
Mon, 13 May 2024 08:02:26 +0000 (11:02 +0300)
* llamafile : improve sgemm.cpp

- Re-enable by default
- Fix issue described in #6716
- Make code more abstract, elegant, and maintainable
- Faster handling of weirdly shaped `m` an `n` edge cases

* Address review comments

* Help clang produce fma instructions

* Address review comments

ggml.c

diff --git a/ggml.c b/ggml.c
index a745104c655cf9b8ab9439a310b6bc74a4077ff2..b9e2150f16e5eb1a08adb1e4180b7f0e95c0c412 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -10887,7 +10887,7 @@ static void ggml_compute_forward_mul_mat(
 #endif
 
 #if GGML_USE_LLAMAFILE
-    if (nb10 == ggml_type_size(src1->type)) {
+    if (src1_cont) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
@@ -10940,15 +10940,13 @@ UseGgmlGemm1:;
     const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
 #if GGML_USE_LLAMAFILE
-    if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) {
+    if (src1->type != vec_dot_type) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
-                                     (const char *)wdata + ggml_row_size(vec_dot_type,
-                                         nb12/ggml_type_size(src1->type)*i12 +
-                                         nb13/ggml_type_size(src1->type)*i13),
+                                     (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),