]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Optimize GGML_OP_CUMSUM (#18417)
authorJeff Bolz <redacted>
Fri, 2 Jan 2026 21:32:30 +0000 (15:32 -0600)
committerGitHub <redacted>
Fri, 2 Jan 2026 21:32:30 +0000 (15:32 -0600)
* vulkan: Optimize GGML_OP_CUMSUM

There are two paths: The preexisting one that does a whole row per workgroup
in a single shader, and one that splits each row into multiple blocks and does
two passes. The first pass computes partials within a block, the second adds
the block partials to compute the final result. The multipass shader is used
when there are a small number of large rows.

In the whole-row shader, handle multiple elements per invocation.

* use 2 ELEM_PER_THREAD for AMD/Intel

* address feedback

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp
ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 27772e11f1ede8be3e42ba6fdf846bfc279485e9..16254457bb13cbc99b0de3a87e60e3cfdf1829ef 100644 (file)
@@ -765,6 +765,9 @@ struct vk_device_struct {
     vk_pipeline pipeline_topk_f32[num_topk_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_cumsum_f32;
+    vk_pipeline pipeline_cumsum_small_f32;
+    vk_pipeline pipeline_cumsum_multipass1_f32;
+    vk_pipeline pipeline_cumsum_multipass2_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
     std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
@@ -4178,7 +4181,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
+    const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32,       "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
 
     ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
 
@@ -8804,7 +8811,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_CUMSUM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_cumsum_f32;
+            if (src0->ne[0] <= 512) {
+                return ctx->device->pipeline_cumsum_small_f32;
+            } else {
+                return ctx->device->pipeline_cumsum_f32;
+            }
         }
         return nullptr;
     case GGML_OP_SOLVE_TRI:
@@ -10708,8 +10719,50 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
 }
 
 static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
-    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
+    vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+    // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.
+    // For fewer, larger rows, use the multipass shader to spread each row across SMs.
+    if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
+        ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
+        return;
+    }
+
+    // First pass computes partial sums within a block, and stores the last partial
+    // to the temp buffer. Second pass sums the block partials from the temp buffer
+    // and adds that to the result of the first pass.
+    vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
+    vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
+    GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);
+
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
+
+    std::array<uint32_t, 3> elements;
+
+    elements[0] = dst->ne[0];
+    elements[1] = (uint32_t)ggml_nrows(dst);
+    elements[2] = 1;
+
+    size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
+
+    if (ctx->prealloc_size_split_k < temp_size) {
+        ctx->prealloc_size_split_k = temp_size;
+        ggml_vk_preallocate_buffers(ctx, subctx);
+    }
+
+    vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
+
+    if (ctx->prealloc_split_k_need_sync) {
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
+    ggml_vk_sync_buffers(ctx, subctx);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
+
+    ctx->prealloc_split_k_need_sync = true;
 }
 
 static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
index a4c8fc354e97b7188419fd34e3f8c570892dc043..75e3c3b0eb44ae683b706179c12e37c09dce10c6 100644 (file)
@@ -14,6 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 128;
 layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+layout (constant_id = 2) const uint ELEM_PER_THREAD = 4;
 
 #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
 
@@ -38,32 +39,45 @@ void main() {
         last_sum = 0;
     }
 
-    uint col = tid;
-    uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
+    uint col = tid * ELEM_PER_THREAD;
+    uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD);
     for (int i = 0; i < num_iter; ++i) {
-        FLOAT_TYPE v = 0;
-        if (col < p.n_cols) {
-            v = FLOAT_TYPE(data_a[src_idx + col]);
+        FLOAT_TYPE v[ELEM_PER_THREAD];
+        FLOAT_TYPE thread_sum = 0;
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            if (col + j < p.n_cols) {
+                thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]);
+            }
+            v[j] = thread_sum;
         }
-        v = subgroupInclusiveAdd(v);
 
+        thread_sum = subgroupExclusiveAdd(thread_sum);
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            v[j] += thread_sum;
+        }
         // Store the largest partial sum for each subgroup, then add the partials for all
         // lower subgroups and the final partial sum from the previous iteration.
         if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
-            partial[subgroup_id] = v;
+            partial[subgroup_id] = v[ELEM_PER_THREAD - 1];
         }
         barrier();
-        for (int j = 0; j < subgroup_id; ++j) {
-            v += partial[j];
+        for (int s = 0; s < subgroup_id; ++s) {
+            [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+                v[j] += partial[s];
+            }
+        }
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            v[j] += last_sum;
         }
-        v += last_sum;
         barrier();
         if (tid == BLOCK_SIZE - 1) {
-            last_sum = v;
+            last_sum = v[ELEM_PER_THREAD - 1];
         }
-        if (col < p.n_cols) {
-            data_d[dst_idx + col] = D_TYPE(v);
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            if (col + j < p.n_cols) {
+                data_d[dst_idx + col + j] = D_TYPE(v[j]);
+            }
         }
-        col += BLOCK_SIZE;
+        col += BLOCK_SIZE * ELEM_PER_THREAD;
     }
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp
new file mode 100644 (file)
index 0000000..6d39f92
--- /dev/null
@@ -0,0 +1,60 @@
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 2) writeonly buffer T {D_TYPE data_t[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.y;
+    const uint tid = gl_LocalInvocationID.x;
+    const uint col = gl_GlobalInvocationID.x;
+
+    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
+    const uint i03_offset = i03 * p.ne01*p.ne02;
+    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
+    const uint i01 = row - i03_offset - i02*p.ne01;
+
+    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
+    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
+
+    uint subgroup_id = tid / SUBGROUP_SIZE;
+
+    FLOAT_TYPE v = 0;
+    if (col < p.n_cols) {
+        v = FLOAT_TYPE(data_a[src_idx + col]);
+    }
+    v = subgroupInclusiveAdd(v);
+
+    // Store the largest partial sum for each subgroup, then add the partials for all
+    // lower subgroups and the final partial sum from the previous iteration.
+    if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
+        partial[subgroup_id] = v;
+    }
+    barrier();
+    for (int j = 0; j < subgroup_id; ++j) {
+        v += partial[j];
+    }
+    barrier();
+    if (tid == BLOCK_SIZE - 1) {
+        data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row] = v;
+    }
+    if (col < p.n_cols) {
+        data_d[dst_idx + col] = D_TYPE(v);
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp
new file mode 100644 (file)
index 0000000..e401893
--- /dev/null
@@ -0,0 +1,66 @@
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) buffer D {D_TYPE data_d[];};
+layout (binding = 2) readonly buffer T {D_TYPE data_t[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE temp[BLOCK_SIZE / SUBGROUP_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.y;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
+    const uint i03_offset = i03 * p.ne01*p.ne02;
+    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
+    const uint i01 = row - i03_offset - i02*p.ne01;
+
+    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
+    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
+
+    const uint col = gl_GlobalInvocationID.x;
+
+    float v = 0;
+    // prefetch value we're adding to
+    if (col < p.n_cols) {
+        v = data_d[dst_idx + col];
+    }
+
+    // compute the sum of all previous blocks
+    uint c = tid;
+    float sum = 0;
+    while (c < gl_WorkGroupID.x) {
+        sum += data_t[c + gl_NumWorkGroups.x * row];
+        c += BLOCK_SIZE;
+    }
+
+    sum = subgroupAdd(sum);
+    if (gl_SubgroupInvocationID == 0) {
+        temp[gl_SubgroupID] = sum;
+    }
+    barrier();
+    sum = 0;
+    [[unroll]] for (uint s = 0; s < BLOCK_SIZE / SUBGROUP_SIZE; ++s) {
+        sum += temp[s];
+    }
+
+    // Add the sum to what the first pass computed
+    if (col < p.n_cols) {
+        data_d[dst_idx + col] = v + sum;
+    }
+}
+
index 1bcd8365beab7058791137fd61f9a9f2eaf077f1..5b61ff9ca26689ffd1d8be5d6e3edae77cc1caf4 100644 (file)
@@ -944,6 +944,8 @@ void process_shaders() {
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
     string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("cumsum_multipass2_f32", "cumsum_multipass2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}));