]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : more precise Q*K in FA vec kernel (llama/10247)
authorGeorgi Gerganov <redacted>
Mon, 11 Nov 2024 06:39:13 +0000 (08:39 +0200)
committerGeorgi Gerganov <redacted>
Fri, 15 Nov 2024 13:21:04 +0000 (15:21 +0200)
ggml/src/ggml-metal.metal

index 413661c8a5d4280ef7166cdc94bf037ffb1cd9f3..e8b71a9f88321db9ab33606e1073e83322b635b4 100644 (file)
@@ -2942,6 +2942,7 @@ kernel void kernel_flash_attn_ext(
                 half smax = -INFINITY;
 
                 // load the mask in shared memory
+                #pragma unroll(Q)
                 for (short j = 0; j < Q; ++j) {
                     device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
 
@@ -2968,7 +2969,7 @@ kernel void kernel_flash_attn_ext(
                         // we can read directly from global memory
                         device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
-#pragma unroll
+                        #pragma unroll(D8)
                         for (short i = 0; i < D8; ++i) {
                             k8x8_t mk;
                             simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
@@ -2989,7 +2990,7 @@ kernel void kernel_flash_attn_ext(
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
-#pragma unroll
+                                #pragma unroll(4)
                                 for (short k = 0; k < 4; ++k) {
                                     k8x8_t mk;
 
@@ -3067,7 +3068,7 @@ kernel void kernel_flash_attn_ext(
                 s8x8_t mm;
                 simdgroup_load(mm, ss + 2*C, TS, 0, false);
 
-#pragma unroll
+                #pragma unroll(D8)
                 for (short i = 0; i < D8; ++i) {
                     simdgroup_multiply(lo[i], mm, lo[i]);
                 }
@@ -3082,7 +3083,8 @@ kernel void kernel_flash_attn_ext(
                     if (is_same<vd4x4_t, v4x4_t>::value) {
                         // we can read directly from global memory
                         device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-#pragma unroll
+
+                        #pragma unroll(D8)
                         for (short i = 0; i < D8; ++i) {
                             v8x8_t mv;
                             simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
@@ -3103,7 +3105,7 @@ kernel void kernel_flash_attn_ext(
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
-#pragma unroll
+                                #pragma unroll(4)
                                 for (short k = 0; k < 4; ++k) {
                                     v8x8_t mv;
 
@@ -3196,6 +3198,7 @@ kernel void kernel_flash_attn_ext(
                 simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
                 simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
 
+                #pragma unroll(D8)
                 for (short i = 0; i < D8; ++i) {
                     o8x8_t t;
 
@@ -3413,6 +3416,7 @@ kernel void kernel_flash_attn_ext_vec(
         // load the queries from shared memory into local memory
         q4x4_t mq[D16/NL];
 
+        #pragma unroll(D16/NL)
         for (short ii = 0; ii < D16; ii += NL) {
             mq[ii/NL] = sq4x4[ii + tx];
         }
@@ -3454,17 +3458,23 @@ kernel void kernel_flash_attn_ext_vec(
 
                     device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
-#pragma unroll
+                    #pragma unroll(D16/NL)
                     for (short ii = 0; ii < D16; ii += NL) {
                         const short i = ii + tx;
 
                         k4x4_t mk;
                         deq_k(pk + i/nl_k, i%nl_k, mk);
 
-                        mqka[0] += dot(mq[ii/NL][0], mk[0]);
-                        mqka[1] += dot(mq[ii/NL][1], mk[1]);
-                        mqka[2] += dot(mq[ii/NL][2], mk[2]);
-                        mqka[3] += dot(mq[ii/NL][3], mk[3]);
+                        // note: this is less precise than the version below
+                        //mqka[0] += dot(mq[ii/NL][0], mk[0]);
+                        //mqka[1] += dot(mq[ii/NL][1], mk[1]);
+                        //mqka[2] += dot(mq[ii/NL][2], mk[2]);
+                        //mqka[3] += dot(mq[ii/NL][3], mk[3]);
+
+                        mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
+                        mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
+                        mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
+                        mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
                     }
 
                     qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
@@ -3513,7 +3523,7 @@ kernel void kernel_flash_attn_ext_vec(
                 ss[tiisg] = vs;
 
                 // O = diag(ms)*O
-#pragma unroll
+                #pragma unroll(D16/NL)
                 for (short ii = 0; ii < D16; ii += NL) {
                     lo[ii/NL] *= ms;
                 }
@@ -3523,13 +3533,12 @@ kernel void kernel_flash_attn_ext_vec(
 
             // O = O + (Q*K^T)*V
             {
-#pragma unroll
                 for (short cc = 0; cc < C/4; ++cc) {
                     device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
                     const s4x4_t ms(ss[4*cc + ty]);
 
-#pragma unroll
+                    #pragma unroll(D16/NL)
                     for (short ii = 0; ii < D16; ii += NL) {
                         const short i = ii + tx;