]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : fix bug in new soft max computation
authorGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 19:00:07 +0000 (21:00 +0200)
committerGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 19:00:07 +0000 (21:00 +0200)
ggml.c

diff --git a/ggml.c b/ggml.c
index f4c96eb42229d494905fb538fd9d0e3e3618a3f9..eefdcdd7f65a93ec6d209c9e9cfa379f1a1a45c2 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -82,8 +82,15 @@ typedef void* thread_ret_t;
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
+
 #define GGML_SOFT_MAX_UNROLL 4
-#define GGML_VEC_DOT_UNROLL 4
+#define GGML_VEC_DOT_UNROLL  4
+
+#ifdef GGML_USE_ACCELERATE
+// uncomment to use vDSP for soft max computation
+// note: not sure if it is actually faster
+//#define GGML_SOFT_MAX_ACCELERATE
+#endif
 
 #if UINTPTR_MAX == 0xFFFFFFFF
     #define GGML_MEM_ALIGN 4
@@ -5975,7 +5982,12 @@ static void ggml_compute_forward_flash_attn_f32(
 
             float sum = 0.0f;
             {
-#ifndef GGML_USE_ACCELERATE
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                max = -max;
+                vDSP_vsadd(S, 1, &max, S, 1, Mup);
+                vvexpf(S, S, &Mup);
+                ggml_vec_sum_f32(Mup, &sum, S);
+#else
                 uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
                 ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
 
@@ -5998,9 +6010,6 @@ static void ggml_compute_forward_flash_attn_f32(
                 for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
                     sum += sump[i];
                 }
-#else
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
 #endif
             }
 
@@ -6202,7 +6211,12 @@ static void ggml_compute_forward_flash_attn_f16(
 
             float sum = 0.0f;
             {
-#ifndef GGML_USE_ACCELERATE
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                max = -max;
+                vDSP_vsadd(S, 1, &max, S, 1, Mup);
+                vvexpf(S, S, &Mup);
+                ggml_vec_sum_f32(Mup, &sum, S);
+#else
                 uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
                 ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
 
@@ -6225,9 +6239,6 @@ static void ggml_compute_forward_flash_attn_f16(
                 for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
                     sum += sump[i];
                 }
-#else
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
 #endif
             }
 
@@ -6244,7 +6255,7 @@ static void ggml_compute_forward_flash_attn_f16(
 #endif
         }
 
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
+        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
 
         for (int i = 0; i < M; i++) {
             S16[i] = GGML_FP32_TO_FP16(S[i]);