]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: optimize ssm_scan (#18630)
authorJeff Bolz <redacted>
Thu, 8 Jan 2026 14:16:54 +0000 (08:16 -0600)
committerGitHub <redacted>
Thu, 8 Jan 2026 14:16:54 +0000 (15:16 +0100)
* vulkan: optimize ssm_scan

* fix warp vs subgroup naming

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp

index 4d3c085f67af687c4a894d6ef2bc26bf9cb58160..7e17f4945da689b70bce46af4d573ed15a839415 100644 (file)
@@ -570,6 +570,7 @@ struct vk_device_struct {
     bool uma;
     bool prefer_host_memory;
     bool float_controls_rte_fp16;
+    bool subgroup_basic;
     bool subgroup_arithmetic;
     bool subgroup_shuffle;
     bool subgroup_ballot;
@@ -4301,8 +4302,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
     if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
-        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
-        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
@@ -4638,6 +4639,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
         }
         device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
 
+        device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                                 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic);
         device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                                       (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
 #ifdef __APPLE__
@@ -9870,8 +9873,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
 
     std::array<uint32_t, 3> elements;
 
-    const int splitH = 16;
-    const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
+    const uint32_t d_state = src0->ne[0];
+    uint32_t num_subgroups = d_state / ctx->device->subgroup_size;
+    const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups);
     const uint32_t num_workgroups_y = n_seq;
     elements = { num_workgroups_x, num_workgroups_y, 1 };
 
@@ -14777,11 +14781,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     return false;
                 }
 
-                const uint32_t SPLIT_H = 16;
+                size_t shmem_size = d_state * sizeof(float);
 
-                size_t stateC_size = SPLIT_H * d_state * sizeof(float);
+                if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) {
+                    return false;
+                }
 
-                if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
+                if (!device->subgroup_basic) {
                     return false;
                 }
 
index 8f67be97995189ceccd96cbb61a00c3972ab41a6..c7416206dbdaed3ea3bb185d02d9a398a8a3ce70 100644 (file)
@@ -1,6 +1,7 @@
 #version 450
 
 #extension GL_EXT_control_flow_attributes : require
+#extension GL_KHR_shader_subgroup_basic : enable
 #if USE_SUBGROUP_ADD
 #extension GL_KHR_shader_subgroup_arithmetic : enable
 #endif
@@ -9,7 +10,8 @@
 
 layout(constant_id = 0) const uint D_STATE = 128;
 layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
-layout(constant_id = 2) const uint SPLIT_H = 16;
+
+const uint32_t c_factor = D_STATE / SUBGROUP_SIZE;
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
@@ -41,22 +43,28 @@ float softplus(float x) {
     }
 }
 
-shared float stateC[SPLIT_H * D_STATE];
+#if !USE_SUBGROUP_ADD
+shared float temp[D_STATE];
+#endif
 
 void main() {
-    const uint tid = gl_LocalInvocationID.x;
-    const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
-    const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
-    const uint seq_idx = gl_WorkGroupID.y;
+    const uint subgroup = gl_SubgroupID;
+    const uint lane     = gl_SubgroupInvocationID;
+    const uint tid      = gl_SubgroupID * SUBGROUP_SIZE + lane;
+    const uint subgroup_idx = gl_WorkGroupID.x  * c_factor + subgroup;
+
+    const uint head_idx =  subgroup_idx / d_head;
+    const uint head_off = (subgroup_idx % d_head) * 4;
+    const uint seq_idx  = gl_WorkGroupID.y;
 
     const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
     const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
-    const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
+    const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;
     const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
     const uint A_base_idx = (head_idx * nb31) / 4;
     const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
     const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
-    const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
+    const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx;
     const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
 
     const uint stride_x = nb12 / 4;
@@ -65,76 +73,52 @@ void main() {
     const uint stride_C = nb52 / 4;
     const uint stride_y = n_head * d_head;
 
-    float state[SPLIT_H];
-    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
-        state[j] = s0[s0_base_idx + j * D_STATE + tid];
-    }
+    float state[c_factor];
 
-    for (uint i = 0; i < n_tok; i++) {
-        const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
+    [[unroll]] for (uint j = 0; j < c_factor; j++) {
+        state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];
+    }
 
-        const float dA = exp(dt_soft_plus * A[A_base_idx]);
+    float a = A[A_base_idx];
 
-        const float B_val = B[B_base_idx + i * stride_B + tid];
-        const float C_val = C[C_base_idx + i * stride_C + tid];
+    for (uint i = 0; i < n_tok; i++) {
+        float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
 
-        [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
-            const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
+        float state_sum = 0.0f;
 
+        const float dA   = exp(dt_soft_plus * a);
+        const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;
+        [[unroll]] for (uint j = 0; j < c_factor; j++) {
+            float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];
+            float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];
             state[j] = (state[j] * dA) + (B_val * x_dt);
-
-            stateC[j * D_STATE + tid] = state[j] * C_val;
+            state_sum += state[j] * C_val;
         }
 
+#if USE_SUBGROUP_ADD
+        state_sum = subgroupAdd(state_sum);
+#else
+        temp[tid] = state_sum;
         barrier();
-        [[unroll]]
-        for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
-            [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
-                const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
-                if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
-                    stateC[k] += stateC[k + w];
-                }
+        [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {
+            if (lane < s) {
+                temp[tid] += temp[tid + s];
             }
             barrier();
         }
-
-        [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
-            const uint idx = (tid % SUBGROUP_SIZE) +
-                            D_STATE * (tid / SUBGROUP_SIZE) +
-                            j * D_STATE * (D_STATE / SUBGROUP_SIZE);
-            const uint max_idx = SUBGROUP_SIZE - 1 +
-                            D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
-                            j * D_STATE * (D_STATE / SUBGROUP_SIZE);
-
-            if (idx < SPLIT_H * D_STATE ||
-                max_idx < SPLIT_H * D_STATE) {
-                float sc;
-#if USE_SUBGROUP_ADD
-                sc = stateC[idx];
-                sc = subgroupAdd(sc);
-#else
-                [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
-                    if (idx + offset < SPLIT_H * D_STATE) {
-                        stateC[idx] += stateC[idx + offset];
-                    }
-                    barrier();
-                }
-                if (tid % SUBGROUP_SIZE == 0) {
-                    sc = stateC[idx];
-                }
+        // get the value from lane 0
+        state_sum = temp[subgroup * SUBGROUP_SIZE];
+        barrier();
 #endif
 
-                if (tid % SUBGROUP_SIZE == 0) {
-                    const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
-                    d[y_base_idx + i * stride_y + k] = sc;
-                }
-            }
+        if (lane == 0) {
+            d[y_base_idx + i * stride_y] = state_sum;
         }
-
-        barrier();
     }
 
-    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
-        d[s_base_idx + j * D_STATE + tid] = state[j];
+    // write back the state
+    [[unroll]]
+    for (int j = 0; j < c_factor; j++) {
+        d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];
     }
 }