]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : fix F32 accumulation in FA vec kernel (#10232)
authorGeorgi Gerganov <redacted>
Sat, 9 Nov 2024 09:52:45 +0000 (11:52 +0200)
committerGitHub <redacted>
Sat, 9 Nov 2024 09:52:45 +0000 (11:52 +0200)
ggml/src/ggml-metal.metal

index 7e151741466b6aa11e1dd0305087643d8c52b514..1f233ba7f8eaab1bac165ff791fe2309daba0468 100644 (file)
@@ -3450,7 +3450,7 @@ kernel void kernel_flash_attn_ext_vec(
             {
                 // each simdgroup processes 1 query and 4 keys
                 for (short cc = 0; cc < C/4; ++cc) {
-                    qk_t mqk = 0.0;
+                    qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
 
                     device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
@@ -3461,13 +3461,14 @@ kernel void kernel_flash_attn_ext_vec(
                         k4x4_t mk;
                         deq_k(pk + i/nl_k, i%nl_k, mk);
 
-                        mqk +=
-                            dot(mq[ii/NL][0], mk[0]) +
-                            dot(mq[ii/NL][1], mk[1]) +
-                            dot(mq[ii/NL][2], mk[2]) +
-                            dot(mq[ii/NL][3], mk[3]);
+                        mqka[0] += dot(mq[ii/NL][0], mk[0]);
+                        mqka[1] += dot(mq[ii/NL][1], mk[1]);
+                        mqka[2] += dot(mq[ii/NL][2], mk[2]);
+                        mqka[3] += dot(mq[ii/NL][3], mk[3]);
                     }
 
+                    qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
+
                     // simdgroup reduce
                     // [ 0 ..  7] -> [ 0]
                     // [ 8 .. 15] -> [ 8]