]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : gemma2 flash attention support (llama/9159)
authorslaren <redacted>
Mon, 26 Aug 2024 09:08:59 +0000 (11:08 +0200)
committerGeorgi Gerganov <redacted>
Tue, 27 Aug 2024 19:01:14 +0000 (22:01 +0300)
src/ggml-metal.m
src/ggml-metal.metal
tests/test-backend-ops.cpp

index f6c36267c51fb38ea03ca5d56868c9e6535aaf0c..670d53861bbe26e994073ff1fca0ed2ab0c15ac6 100644 (file)
@@ -817,15 +817,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
-            {
-                float logit_softcap;
-
-                memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
-
-                if (logit_softcap != 0.0f) {
-                    return false;
-                }
-            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -2693,9 +2684,14 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         float scale;
                         float max_bias;
+                        float logit_softcap;
+                        memcpy(&scale,         ((int32_t *) dst->op_params) + 0, sizeof(scale));
+                        memcpy(&max_bias,      ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+                        memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
 
-                        memcpy(&scale,    ((int32_t *) dst->op_params) + 0, sizeof(scale));
-                        memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+                        if (logit_softcap != 0.0f) {
+                            scale /= logit_softcap;
+                        }
 
                         const uint32_t n_head      = src0->ne[2];
                         const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -2746,30 +2742,31 @@ static enum ggml_status ggml_metal_graph_compute(
                         } else {
                             [encoder setBuffer:id_src0     offset:offs_src0           atIndex:3];
                         }
-                        [encoder setBuffer:id_dst      offset:offs_dst            atIndex:4];
-                        [encoder setBytes:&ne01        length:sizeof( int64_t)    atIndex:5];
-                        [encoder setBytes:&ne02        length:sizeof( int64_t)    atIndex:6];
-                        [encoder setBytes:&ne03        length:sizeof( int64_t)    atIndex:7];
-                        [encoder setBytes:&nb01        length:sizeof(uint64_t)    atIndex:8];
-                        [encoder setBytes:&nb02        length:sizeof(uint64_t)    atIndex:9];
-                        [encoder setBytes:&nb03        length:sizeof(uint64_t)    atIndex:10];
-                        [encoder setBytes:&ne11        length:sizeof( int64_t)    atIndex:11];
-                        [encoder setBytes:&ne12        length:sizeof( int64_t)    atIndex:12];
-                        [encoder setBytes:&ne13        length:sizeof( int64_t)    atIndex:13];
-                        [encoder setBytes:&nb11        length:sizeof(uint64_t)    atIndex:14];
-                        [encoder setBytes:&nb12        length:sizeof(uint64_t)    atIndex:15];
-                        [encoder setBytes:&nb13        length:sizeof(uint64_t)    atIndex:16];
-                        [encoder setBytes:&nb21        length:sizeof(uint64_t)    atIndex:17];
-                        [encoder setBytes:&nb22        length:sizeof(uint64_t)    atIndex:18];
-                        [encoder setBytes:&nb23        length:sizeof(uint64_t)    atIndex:19];
-                        [encoder setBytes:&nb31        length:sizeof(uint64_t)    atIndex:20];
-                        [encoder setBytes:&ne1         length:sizeof( int64_t)    atIndex:21];
-                        [encoder setBytes:&ne2         length:sizeof( int64_t)    atIndex:22];
-                        [encoder setBytes:&scale       length:sizeof(   float)    atIndex:23];
-                        [encoder setBytes:&max_bias    length:sizeof(   float)    atIndex:24];
-                        [encoder setBytes:&m0          length:sizeof(m0)          atIndex:25];
-                        [encoder setBytes:&m1          length:sizeof(m1)          atIndex:26];
-                        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
+                        [encoder setBuffer:id_dst        offset:offs_dst              atIndex:4];
+                        [encoder setBytes:&ne01          length:sizeof( int64_t)      atIndex:5];
+                        [encoder setBytes:&ne02          length:sizeof( int64_t)      atIndex:6];
+                        [encoder setBytes:&ne03          length:sizeof( int64_t)      atIndex:7];
+                        [encoder setBytes:&nb01          length:sizeof(uint64_t)      atIndex:8];
+                        [encoder setBytes:&nb02          length:sizeof(uint64_t)      atIndex:9];
+                        [encoder setBytes:&nb03          length:sizeof(uint64_t)      atIndex:10];
+                        [encoder setBytes:&ne11          length:sizeof( int64_t)      atIndex:11];
+                        [encoder setBytes:&ne12          length:sizeof( int64_t)      atIndex:12];
+                        [encoder setBytes:&ne13          length:sizeof( int64_t)      atIndex:13];
+                        [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
+                        [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
+                        [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
+                        [encoder setBytes:&nb21          length:sizeof(uint64_t)      atIndex:17];
+                        [encoder setBytes:&nb22          length:sizeof(uint64_t)      atIndex:18];
+                        [encoder setBytes:&nb23          length:sizeof(uint64_t)      atIndex:19];
+                        [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:20];
+                        [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:21];
+                        [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:22];
+                        [encoder setBytes:&scale         length:sizeof(   float)      atIndex:23];
+                        [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:24];
+                        [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25];
+                        [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26];
+                        [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27];
+                        [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
 
                         if (!use_vec_kernel) {
                             // half8x8 kernel
index 17432085c03ad7e6b9e93554665b277aa9c9c133..2de9a592990033495d9687717e0c685663f92877 100644 (file)
@@ -2056,6 +2056,7 @@ typedef void (flash_attn_ext_f16_t)(
         constant     float & m0,
         constant     float & m1,
         constant  uint32_t & n_head_log2,
+        constant     float & logit_softcap,
         threadgroup   half * shared,
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3  tpitg[[thread_position_in_threadgroup]],
@@ -2094,6 +2095,7 @@ kernel void kernel_flash_attn_ext_f16(
         constant     float & m0,
         constant     float & m1,
         constant  uint32_t & n_head_log2,
+        constant     float & logit_softcap,
         threadgroup   half * shared [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3  tpitg[[thread_position_in_threadgroup]],
@@ -2222,14 +2224,19 @@ kernel void kernel_flash_attn_ext_f16(
                     const short tx = tiisg%4;
                     const short ty = tiisg/4;
 
+                    // mqk = mqk*scale
+                    ss[8*cc + ty*TF + 2*tx + 0] *= scale;
+                    ss[8*cc + ty*TF + 2*tx + 1] *= scale;
+
+                    if (logit_softcap != 0.0f) {
+                        ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
+                        ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
+                    }
+
                     if (mask != q) {
-                        // mqk = mqk*scale + mask*slope
-                        ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
-                        ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
-                    } else {
-                        // mqk = mqk*scale
-                        ss[8*cc + ty*TF + 2*tx + 0] *= scale;
-                        ss[8*cc + ty*TF + 2*tx + 1] *= scale;
+                        // mqk = mqk + mask*slope
+                        ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
+                        ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
                     }
                 }
             }
@@ -2425,6 +2432,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
         constant     float & m0,
         constant     float & m1,
         constant  uint32_t & n_head_log2,
+        constant     float & logit_softcap,
         threadgroup   half * shared [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3  tpitg[[thread_position_in_threadgroup]],
@@ -2559,7 +2567,13 @@ kernel void kernel_flash_attn_ext_vec_f16(
 
                     // mqk = mqk*scale + mask*slope
                     if (tiisg == 0) {
-                        mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
+                        mqk *= scale;
+
+                        if (logit_softcap != 0.0f) {
+                            mqk = logit_softcap*precise::tanh(mqk);
+                        }
+
+                        mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
 
                         ss4[cc] = mqk;
                     }
index fce5419795d4e350607a57c05cd0bb0193f55384..eb1de59ac48a003b67be509fcf42308920e3a122 100644 (file)
@@ -2564,7 +2564,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     GGML_ABORT("fatal error");
-    return false;
 }
 
 static void usage(char ** argv) {