]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : improve vec_dot_f16 unrolling in flash_attn_f16
authorGeorgi Gerganov <redacted>
Sun, 8 Jan 2023 09:41:18 +0000 (11:41 +0200)
committerGeorgi Gerganov <redacted>
Sun, 8 Jan 2023 09:41:18 +0000 (11:41 +0200)
examples/command/command.cpp
ggml.c

index 74a14f9cc867d3542dd95d16b8bda0c8d7f41b52..4558a67dae9b575e0ab9f1e05d124866d4c3305e 100644 (file)
@@ -781,7 +781,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
                 std::string prompt;
                 std::string command;
 
-                for (int i = 0; i < words.size(); ++i) {
+                for (int i = 0; i < (int) words.size(); ++i) {
                     if (i < k_prompt_length) {
                         prompt += words[i] + " ";
                     } else {
diff --git a/ggml.c b/ggml.c
index eefdcdd7f65a93ec6d209c9e9cfa379f1a1a45c2..ccbd6c74a922853224f4b2c834ab5d2985f0f545 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -84,7 +84,7 @@ typedef void* thread_ret_t;
 #define GGML_GELU_FP16
 
 #define GGML_SOFT_MAX_UNROLL 4
-#define GGML_VEC_DOT_UNROLL  4
+#define GGML_VEC_DOT_UNROLL  2
 
 #ifdef GGML_USE_ACCELERATE
 // uncomment to use vDSP for soft max computation
@@ -923,9 +923,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
     ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
 
-    const ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL] = { xv };
+    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
 
-    for (int i = 1; i < GGML_VEC_DOT_UNROLL; ++i) {
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
         x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
     }
 
@@ -6158,40 +6158,37 @@ static void ggml_compute_forward_flash_attn_f16(
             S[i] = -INFINITY;
         }
 
-        // looks like unrolling here does not help
-#if 1
-        for (int ic = 0; ic < nek1; ++ic) {
-            // k indices
-            const int ik3 = iq3;
-            const int ik2 = iq2;
-            const int ik1 = ic;
-
-            // S indices
-            const int i1 = ik1;
-
-            ggml_vec_dot_f16(neq0,
-                    S + i1,
-                    (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                    (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-        }
-#else
-        GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0);
-
-        for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
-            // k indices
-            const int ik3 = iq3;
-            const int ik2 = iq2;
-            const int ik1 = ic;
+        if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
+            for (int ic = 0; ic < nek1; ++ic) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
 
-            // S indices
-            const int i1 = ik1;
+                // S indices
+                const int i1 = ik1;
 
-            ggml_vec_dot_f16_unroll(neq0, nbk1,
-                    S + i1,
-                                    ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                    (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+                ggml_vec_dot_f16(neq0,
+                        S + i1,
+                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+        } else {
+            for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f16_unroll(neq0, nbk1,
+                        S + i1,
+                        ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
         }
-#endif
 
         // scale
         ggml_vec_scale_f32(nek1, S, scale);
@@ -6261,18 +6258,30 @@ static void ggml_compute_forward_flash_attn_f16(
             S16[i] = GGML_FP32_TO_FP16(S[i]);
         }
 
-        GGML_ASSERT(nev1 % GGML_VEC_DOT_UNROLL == 0);
+        if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
+            for (int ic = 0; ic < nev1; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
 
-        for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
+                ggml_vec_dot_f16(nek1,
+                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
+                        (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        S16);
+            }
+        } else {
+            for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
 
-            ggml_vec_dot_f16_unroll(nek1, nbv1,
-                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
-                              ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
-                    S16);
+                ggml_vec_dot_f16_unroll(nek1, nbv1,
+                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
+                        ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        S16);
+            }
         }
     }
 }