]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Optimize SSM_SCAN (#16645)
authorJeff Bolz <redacted>
Sat, 25 Oct 2025 05:04:12 +0000 (00:04 -0500)
committerGitHub <redacted>
Sat, 25 Oct 2025 05:04:12 +0000 (07:04 +0200)
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 21bd0522555643e5806e7f6d3eb2a34fcf350487..5e6b751ae372d42f3bff17e193892fad727b2b09 100644 (file)
@@ -3623,8 +3623,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
+    if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
+    } else {
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
+    }
 
     ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
 
index 12bd1745790523eb99d5ba2ed09f4e363faf1cb5..8f67be97995189ceccd96cbb61a00c3972ab41a6 100644 (file)
@@ -1,6 +1,9 @@
 #version 450
 
 #extension GL_EXT_control_flow_attributes : require
+#if USE_SUBGROUP_ADD
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#endif
 
 #include "types.glsl"
 
@@ -84,35 +87,47 @@ void main() {
         }
 
         barrier();
-        for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
-            [[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
-                const uint k = (tid % (w >> 1)) +
-                              (D_STATE * (tid / (w >> 1))) +
-                              j * D_STATE * (D_STATE / (w >> 1));
-                if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
-                    stateC[k] += stateC[k + (w >> 1)];
+        [[unroll]]
+        for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
+            [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
+                const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
+                if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
+                    stateC[k] += stateC[k + w];
                 }
             }
             barrier();
         }
 
-        [[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
+        [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
             const uint idx = (tid % SUBGROUP_SIZE) +
                             D_STATE * (tid / SUBGROUP_SIZE) +
                             j * D_STATE * (D_STATE / SUBGROUP_SIZE);
+            const uint max_idx = SUBGROUP_SIZE - 1 +
+                            D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
+                            j * D_STATE * (D_STATE / SUBGROUP_SIZE);
 
-            uint lane = tid % SUBGROUP_SIZE;
-
-            [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
-                if (idx + offset < SPLIT_H * D_STATE) {
-                    stateC[idx] += stateC[idx + offset];
+            if (idx < SPLIT_H * D_STATE ||
+                max_idx < SPLIT_H * D_STATE) {
+                float sc;
+#if USE_SUBGROUP_ADD
+                sc = stateC[idx];
+                sc = subgroupAdd(sc);
+#else
+                [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
+                    if (idx + offset < SPLIT_H * D_STATE) {
+                        stateC[idx] += stateC[idx + offset];
+                    }
+                    barrier();
                 }
-                barrier();
-            }
+                if (tid % SUBGROUP_SIZE == 0) {
+                    sc = stateC[idx];
+                }
+#endif
 
-            if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
-                const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
-                d[y_base_idx + i * stride_y + k] = stateC[idx];
+                if (tid % SUBGROUP_SIZE == 0) {
+                    const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
+                    d[y_base_idx + i * stride_y + k] = sc;
+                }
             }
         }
 
index 49bf6c764f726efd1d86dffc20395795f579a34e..0f25ba3453093306a1965da6b7ab139601b728e8 100644 (file)
@@ -916,7 +916,8 @@ void process_shaders() {
     string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
     string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
 
-    string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
+    string_to_spv("ssm_scan_f32",          "ssm_scan.comp", {{"A_TYPE", "float"}});
+    string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
 
     string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});