]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : unroll ggml_vec_dot_f16 in ggml_compute_forward_flash_attn_f16
authorGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 15:32:23 +0000 (17:32 +0200)
committerGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 17:19:40 +0000 (19:19 +0200)
ggml.c

diff --git a/ggml.c b/ggml.c
index 7db762feb99f968be6d73a6e996600fc77e0e73b..e627164a80bd2c48570cd98f7af00fc3cca8b894 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -79,9 +79,11 @@ typedef void* thread_ret_t;
 #define static_assert(cond, msg) _Static_assert(cond, msg)
 #endif
 
+/*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
 #define GGML_SOFT_MAX_UNROLL 4
+#define GGML_VEC_DOT_UNROLL 4
 
 #if UINTPTR_MAX == 0xFFFFFFFF
     #define GGML_MEM_ALIGN 4
@@ -909,6 +911,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     *s = sumf;
 }
 
+// compute GGML_VEC_DOT_UNROLL dot products at once
+// xs - x row stride in bytes
+inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
+    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+
+    const ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL] = { xv };
+
+    for (int i = 1; i < GGML_VEC_DOT_UNROLL; ++i) {
+        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+
+                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+            }
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+        }
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+        }
+    }
+#endif
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        s[i] = sumf[i];
+    }
+}
+
 inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
 #if defined(GGML_SIMD)
     const int np = (n & ~(GGML_F32_STEP - 1));
@@ -4720,9 +4777,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
             float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
-            for (int ic = 0; ic < ne11; ++ic) {
-                assert(ne00 % 32 == 0);
+            assert(ne00 % 32 == 0);
 
+            for (int ic = 0; ic < ne11; ++ic) {
                 ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
             }
         }
@@ -6092,6 +6149,8 @@ static void ggml_compute_forward_flash_attn_f16(
             S[i] = -INFINITY;
         }
 
+        // looks like unrolling here does not help
+#if 1
         for (int ic = 0; ic < nek1; ++ic) {
             // k indices
             const int ik3 = iq3;
@@ -6106,6 +6165,24 @@ static void ggml_compute_forward_flash_attn_f16(
                     (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
                     (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
         }
+#else
+        GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0);
+
+        for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
+            // k indices
+            const int ik3 = iq3;
+            const int ik2 = iq2;
+            const int ik1 = ic;
+
+            // S indices
+            const int i1 = ik1;
+
+            ggml_vec_dot_f16_unroll(neq0, nbk1,
+                    S + i1,
+                                    ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                    (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+        }
+#endif
 
         // scale
         ggml_vec_scale_f32(nek1, S, scale);
@@ -6173,15 +6250,17 @@ static void ggml_compute_forward_flash_attn_f16(
             S16[i] = GGML_FP32_TO_FP16(S[i]);
         }
 
-        for (int ic = 0; ic < nev1; ++ic) {
+        GGML_ASSERT(nev1 % GGML_VEC_DOT_UNROLL == 0);
+
+        for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
             // dst indices
             const int i1 = iq1;
             const int i2 = iq2;
             const int i3 = iq3;
 
-            ggml_vec_dot_f16(nek1,
-                    (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
-                    (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+            ggml_vec_dot_f16_unroll(nek1, nbv1,
+                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
+                              ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
                     S16);
         }
     }