]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: fix SSM_CONV PP scaling with large ubatch sizes (llama/20379)
authorProgenyAlpha <redacted>
Thu, 12 Mar 2026 09:03:18 +0000 (05:03 -0400)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
* vulkan: optimize SSM_CONV workgroup dispatch for large ubatch

Tile tokens into 2D workgroups (32x16) to reduce workgroup launch
overhead at large ubatch sizes. Add vec4 fast path for nc=4 (common
d_conv size). Fixes PP performance degradation with ubatch > 512.

Ref: ggml-org/llama.cpp#18725

Co-Authored-By: Claude Opus 4.6 <redacted>
* vulkan: remove unused shared memory declaration in SSM_CONV

Co-Authored-By: Claude Opus 4.6 <redacted>
---------

Co-authored-by: Progeny Alpha <redacted>
Co-authored-by: Claude Opus 4.6 <redacted>
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp

index ce3c85e7589a65a31e2431cbec4b090c17892f5d..2a2f7f4f11c86baf517e68ff6971c3b9452add25 100644 (file)
@@ -4576,7 +4576,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
     }
 
-    ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
index d62696bcfaecfc8cf3df1ce54bd6522a43b8f0b2..6802b1fc955dd307a32832c780e01a00b1b9d4f5 100644 (file)
@@ -5,8 +5,9 @@
 #include "types.glsl"
 
 layout(constant_id = 0) const uint BLOCK_SIZE = 32;
+layout(constant_id = 1) const uint TOKENS_PER_WG = 16;
 
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
 
 layout(binding = 0) readonly buffer Src0 { float src0[]; };
 layout(binding = 1) readonly buffer Src1 { float src1[]; };
@@ -20,25 +21,30 @@ layout(push_constant) uniform PushConstants {
 };
 
 void main() {
-    const uint global_thread_id = gl_GlobalInvocationID.x;
-    const uint i2 = gl_WorkGroupID.y;
+    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i2 = gl_WorkGroupID.y * TOKENS_PER_WG + gl_LocalInvocationID.y;
     const uint i3 = gl_WorkGroupID.z;
 
-    if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
+    if (i1 >= nr || i2 >= n_t || i3 >= n_s) {
         return;
     }
 
-    const uint i1 = global_thread_id;
     const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
     const uint src1_base = i1 * (nb11 / 4);
-    const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
 
     float sum = 0.0;
-    [[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
-        const uint src0_idx = src0_base + i0;
-        const uint src1_idx = src1_base + i0;
-        sum += src0[src0_idx] * src1[src1_idx];
+
+    if (nc == 4) {
+        sum = dot(
+            vec4(src0[src0_base], src0[src0_base + 1], src0[src0_base + 2], src0[src0_base + 3]),
+            vec4(src1[src1_base], src1[src1_base + 1], src1[src1_base + 2], src1[src1_base + 3])
+        );
+    } else {
+        [[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
+            sum += src0[src0_base + i0] * src1[src1_base + i0];
+        }
     }
 
+    const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
     dst[dst_idx] = sum;
 }