]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
finetune : speed-up ggml_compute_forward_out_prod_f32 via BLAS (#4079)
authorgwjr <redacted>
Fri, 17 Nov 2023 14:48:19 +0000 (14:48 +0000)
committerGitHub <redacted>
Fri, 17 Nov 2023 14:48:19 +0000 (16:48 +0200)
* Remove logically superfluous assertions and order by dimension

* Use cblas_sgemm() to implement ggml_compute_forward_out_prod()

* Remove ggml_compute_forward_out_prod_use_blas(), fix compiling errors on cmake/zig, remove trailing whitespace

* Add openBLAS support for sgemm() in compute_forward_out_prod()

ggml.c

diff --git a/ggml.c b/ggml.c
index ada1067da56d45ba5111311ad390671b940c2815..c7086ba844c6006584cf38e8857eb3c3ab519516 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -9611,10 +9611,12 @@ static void ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+    GGML_ASSERT(ne2  == ne02);
     GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
     GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne03 == ne13);
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9625,18 +9627,25 @@ static void ggml_compute_forward_out_prod_f32(
     // GGML_ASSERT(nb1 <= nb2);
     // GGML_ASSERT(nb2 <= nb3);
 
-    GGML_ASSERT(ne0 == ne00);
-    GGML_ASSERT(ne1 == ne10);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
     // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
-    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+    // TODO: #if defined(GGML_USE_CLBLAST)
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    bool use_blas = ggml_is_matrix(src0) &&
+        ggml_is_matrix(src1) &&
+        ggml_is_contiguous(src0) &&
+        (ggml_is_contiguous(src1) || ggml_is_transposed(src1));
+#endif
 
     if (params->type == GGML_TASK_INIT) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
+        if (use_blas) {
+            return;
+        }
+#endif
         ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
         return;
     }
@@ -9645,6 +9654,50 @@ static void ggml_compute_forward_out_prod_f32(
         return;
     }
 
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (use_blas) {
+        if (params->ith != 0) { // All threads other than the first do no work.
+            return;
+        }
+        // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+        // src0: (k,n)
+        // src1: (k,m)
+        // dst:  (m,n)
+        //
+        // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+        // Also expressed as (major,minor)
+        // a: (m,k): so src1 transposed
+        // b: (k,n): so src0
+        // c: (m,n)
+        //
+        // However, if ggml_is_transposed(src1) is true, then
+        // src1->data already contains a transposed version, so sgemm mustn't
+        // transpose it further.
+
+        int n = src0->ne[0];
+        int k = src0->ne[1];
+        int m = src1->ne[0];
+
+        int transposeA, lda;
+
+        if (!ggml_is_transposed(src1)) {
+            transposeA = CblasTrans;
+            lda = m;
+        } else {
+            transposeA = CblasNoTrans;
+            lda = k;
+        }
+
+        float * a = (float *) ((char *) src1->data);
+        float * b = (float *) ((char *) src0->data);
+        float * c = (float *) ((char *) dst->data);
+
+        cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+        return;
+    }
+#endif
+
     // dst[:,:,:,:] = 0
     // for i2,i3:
     //   for i1: