]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : separate scale and mask from QKT in FA kernel (llama/9189)
authorGeorgi Gerganov <redacted>
Mon, 26 Aug 2024 15:31:02 +0000 (18:31 +0300)
committerGeorgi Gerganov <redacted>
Tue, 27 Aug 2024 19:01:14 +0000 (22:01 +0300)
* metal : separate scale and mask from QKT in FA kernel

* metal : ne01 check no longer necessary

* metal : keep data in local memory

src/ggml-metal.metal

index c2b6b740039a2c709b32fa9fddcc6fbc146ef12f..f323ab5f447d5497259405f9e3eb5cb827f4aedd 100644 (file)
@@ -2341,24 +2341,6 @@ kernel void kernel_flash_attn_ext_f16(
                     }
 
                     simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
-
-                    const short tx = tiisg%4;
-                    const short ty = tiisg/4;
-
-                    // mqk = mqk*scale
-                    ss[8*cc + ty*TF + 2*tx + 0] *= scale;
-                    ss[8*cc + ty*TF + 2*tx + 1] *= scale;
-
-                    if (logit_softcap != 0.0f) {
-                        ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
-                        ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
-                    }
-
-                    if (mask != q) {
-                        // mqk = mqk + mask*slope
-                        ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
-                        ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
-                    }
                 }
             }
 
@@ -2370,10 +2352,19 @@ kernel void kernel_flash_attn_ext_f16(
                 float ms[Q];
 
                 for (short j = 0; j < Q; ++j) {
-                    const short p = tiisg;
-
                     const float m = M[j];
-                    const float s = ss[j*TF + p];
+
+                    // scale and apply the logitcap / mask
+                    float s = ss[j*TF + tiisg]*scale;
+
+                    if (logit_softcap != 0.0f) {
+                        s = logit_softcap*precise::tanh(s);
+                    }
+
+                    if (mask != q) {
+                        // mqk = mqk + mask*slope
+                        s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
+                    }
 
                     smax = simd_max(max(smax, s));
                     M[j] = simd_max(max(M[j], s));
@@ -2384,7 +2375,7 @@ kernel void kernel_flash_attn_ext_f16(
                     S[j] = S[j]*ms[j] + simd_sum(vs);
 
                     // the P matrix from the paper (Q rows, C columns)
-                    ss[j*TF + p] = vs;
+                    ss[j*TF + tiisg] = vs;
                 }
 
                 // create a QxQ diagonal matrix for rescaling the output