]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal: SSM_SCAN performance (llama/14743)
authorGabe Goodhart <redacted>
Fri, 25 Jul 2025 16:47:39 +0000 (10:47 -0600)
committerGeorgi Gerganov <redacted>
Mon, 28 Jul 2025 05:43:21 +0000 (08:43 +0300)
* feat: Add s_off as a parameter in the args struct

This may not be necessary, but it more closely mirrors the CUDA kernel

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <redacted>
* perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state

This is a first attempt at optimizing the metal kernel. The changes here
are:

- Launch the kernel with a thread group of size d_state
- Use simd groups and shared memory to do the summation for the y
  computation

When tested with G4 tiny preview, this shows roughly a 3x speedup on
prefill and 15% speedup on decode.

Signed-off-by: Gabe Goodhart <redacted>
* fix: Update logic to correctly do the multi-layer parallel sum

Signed-off-by: Gabe Goodhart <redacted>
* fix: Correctly size the shared memory bufer and assert expected size relationships

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <redacted>
* refactor: Compute block offsets once rather than once per token

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <redacted>
* feat: Use local variable for state recursion

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <redacted>
* feat: Use a secondary simd_sum instead of a for loop

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <redacted>
* feat: Add assertion and comment about relationship between simd size and num simd groups

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <redacted>
* feat: Parallelize of d_state for mamba-1

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <redacted>
* feat: Parallel sum in SSM_CONV

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <redacted>
* Revert "feat: Parallel sum in SSM_CONV"

After discussion with @compilade, the size of the parallelism here is
not worth the cost in complexity or overhead of the parallel for.

https://github.com/ggml-org/llama.cpp/pull/14743#discussion_r2223395357

This reverts commit 16bc059660c1c59e566628201c0ca2c20c9f4bc3.

Signed-off-by: Gabe Goodhart <redacted>
* refactor: Simplify shared memory sizing

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <redacted>
Co-Authored-By: Georgi Gerganov <redacted>
---------

Signed-off-by: Gabe Goodhart <redacted>
Co-authored-by: Georgi Gerganov <redacted>
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal.m
src/ggml-metal/ggml-metal.metal

index b7b3fc49af35db9e9d57ed9d27838d197f3cfb82..8424464d8cadca4926c992a6142d185574b5ea37 100644 (file)
@@ -528,6 +528,7 @@ typedef struct {
     int64_t  n_group;
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
+    int64_t  s_off;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
index 1a9999325fe27ac2ab15614d2953cc2689517d4f..337f7985badf328ed928282509725c94ce2e5caa 100644 (file)
@@ -3141,6 +3141,7 @@ static int ggml_metal_encode_node(
                     /*.n_group      =*/ n_group,
                     /*.n_seq_tokens =*/ n_seq_tokens,
                     /*.n_seqs       =*/ n_seqs,
+                    /*.s_off        =*/ ggml_nelements(src1) * sizeof(float),
                     /*.nb01         =*/ nb01,
                     /*.nb02         =*/ nb02,
                     /*.nb03         =*/ nb03,
@@ -3169,12 +3170,22 @@ static int ggml_metal_encode_node(
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:7];
                 [encoder setBytes:&args    length:sizeof(args) atIndex:8];
 
+                // One shared memory bucket for each simd group in the threadgroup
+                // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
+                //  https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                if (d_state >= 32) {
+                    GGML_ASSERT((int64_t)(d_state / 32) <= 32);
+                    const int64_t shmem_size = 32;
+                    GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
+                    [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+                }
+
                 if (ne30 == 1) {
                     // Mamba-2
-                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 } else {
                     GGML_ASSERT(d_inner == 1);
-                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 }
             } break;
         case GGML_OP_RWKV_WKV6:
index f62b9ad548e695ce00a8c8aa0681f59e59efc37f..99a453090f6b0ab58636a2ba53378555342d69ca 100644 (file)
@@ -1823,10 +1823,16 @@ kernel void kernel_ssm_scan_f32(
         device const void * src5,
         device const void * src6,
         device      float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = 0;
     const int64_t ir = tgpig.x; // current head
     const int64_t i3 = tgpig.y; // current seq
@@ -1841,41 +1847,88 @@ kernel void kernel_ssm_scan_f32(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
 
-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+        device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31);
+        device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+        device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+        device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+        device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+        device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block + i2*args.nb12);    // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21);   // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block + i2*args.nb42);    // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block + i2*args.nb52);    // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
 
         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;
 
-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
+        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each SIMD
+        // group to compute the sum of each SIMD group, then place the result in
+        // the SIMD group's indexed bucket in the shared memory. We then sum
+        // over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
 
-        y[0] = sumf;
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        if (sgptg > 1) {
+
+            // Once per simd group, place the group sum into the shared buffer
+            if (tiisg == 0) {
+                shared[sgitg] = sumf;
+            }
+
+            // Wait for all threads in the threadgroup to reach this point. This
+            // ensures that all elements of the shared buffer are populated with the
+            // sum of the individual simd groups.
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // For simd group 0 at indices < num simd groups, extract the shared
+            // simd sum
+            sumf = 0.0f;
+            if (sgitg == 0) {
+                if (tiisg < sgptg) {
+                    sumf = shared[tiisg];
+                }
+                sumf = simd_sum(sumf);
+                if (tiisg == 0) {
+                    y[0] = sumf;
+                }
+            }
+        } else if (tiisg == 0) {
+            y[0] = sumf;
+        }
 
         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-// TODO: optimize (e.g. by parallelizing over d_state)
 kernel void kernel_ssm_scan_f32_group(
         device const void * src0,
         device const void * src1,
@@ -1885,10 +1938,16 @@ kernel void kernel_ssm_scan_f32_group(
         device const void * src5,
         device const void * src6,
         device      float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = tgpig.x;
     const int64_t ir = tgpig.y; // current head
     const int64_t i3 = tgpig.z; // current seq
@@ -1903,38 +1962,81 @@ kernel void kernel_ssm_scan_f32_group(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
 
-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12);    // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21);    // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42);    // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52);    // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
 
         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
         const float dA = exp(dt_soft_plus * A[0]);
-        float sumf = 0.0f;
 
-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * dA) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
+        const float state = (s0 * dA) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each SIMD
+        // group to compute the sum of each SIMD group, then place the result in
+        // the SIMD group's indexed bucket in the shared memory. We then sum
+        // over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        // Once per simd group, place the group sum into the shared buffer
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
         }
 
-        y[0] = sumf;
+        // Wait for all threads in the threadgroup to reach this point. This
+        // ensures that all elements of the shared buffer are populated with the
+        // sum of the individual simd groups.
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // For simd group 0 at indices < num simd groups, extract the shared
+        // simd sum
+        sumf = 0.0f;
+        if (sgitg == 0) {
+            if (tiisg < sgptg) {
+                sumf = shared[tiisg];
+            }
+            sumf = simd_sum(sumf);
+            if (tiisg == 0) {
+                y[0] = sumf;
+            }
+        }
 
         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }
 
 kernel void kernel_rwkv_wkv6_f32(