]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : use F32 prec for K*Q in vec FA (#9595)
authorGeorgi Gerganov <redacted>
Mon, 23 Sep 2024 08:27:47 +0000 (11:27 +0300)
committerGitHub <redacted>
Mon, 23 Sep 2024 08:27:47 +0000 (11:27 +0300)
ggml-ci

ggml/src/ggml-metal.metal

index f323ab5f447d5497259405f9e3eb5cb827f4aedd..2b200032394b1f041f647ec519ffc3a7fc9e7f98 100644 (file)
@@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
         const short iv3 = iq3 / rv3;
 
         // load the queries from shared memory into local memory
-        half4 mq[D4];
+        float4 mq[D4];
 
         for (short ii = 0; ii < D4; ii += NW) {
             short i = ii + tiisg;
-            mq[i] = sq4[i];
+            mq[i] = (float4) sq4[i];
         }
 
         // pointer to the mask
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
                     for (short ii = 0; ii < D4; ii += NW) {
                         const short i = ii + tiisg;
 
-                        half4x4 mk;
-                        mk[0] = pk4[i + 0*(nb11/8)];
-                        mk[1] = pk4[i + 1*(nb11/8)];
-                        mk[2] = pk4[i + 2*(nb11/8)];
-                        mk[3] = pk4[i + 3*(nb11/8)];
+                        float4x4 mk;
+                        mk[0] = (float4) pk4[i + 0*(nb11/8)];
+                        mk[1] = (float4) pk4[i + 1*(nb11/8)];
+                        mk[2] = (float4) pk4[i + 2*(nb11/8)];
+                        mk[3] = (float4) pk4[i + 3*(nb11/8)];
 
                         mqk += (float4) (mq[i] * mk);
                     }