]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: support larger argsort (llama/17313)
authorJeff Bolz <redacted>
Wed, 19 Nov 2025 16:25:50 +0000 (10:25 -0600)
committerGeorgi Gerganov <redacted>
Fri, 12 Dec 2025 15:53:04 +0000 (17:53 +0200)
* vulkan: support larger argsort

This is an extension of the original bitonic sorting shader that puts the
temporary values in global memory and when more than 1024 threads are needed
it runs multiple workgroups and synchronizes through a pipelinebarrier.

To improve the memory access pattern, a copy of the float value is kept with
the index value. I've applied this same change to the original shared memory
version of the shader, which is still used when ncols <= 1024.

* Reduce the number of shader variants. Use smaller workgroups when doing a single pass, for a modest perf boost

* reduce loop overhead

* run multiple cols per invocation, to reduce barrier overhead

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 2ef9ab14aa8f84de2345ab5ad1397b83d1789d83..691af7bc263db3b7dbd9b1b9e5c67d054fcbf18f 100644 (file)
@@ -406,8 +406,8 @@ enum shader_reduction_mode {
     SHADER_REDUCTION_MODE_COUNT,
 };
 
+// argsort pipelines for up to 1<<10 invocations per workgroup
 static constexpr uint32_t num_argsort_pipelines = 11;
-static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
 static constexpr uint32_t num_topk_moe_pipelines = 10;
 
 static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
@@ -526,6 +526,7 @@ struct vk_device_struct {
     bool multi_add;
     bool shader_int64;
     bool buffer_device_address;
+    bool vulkan_memory_model;
 
     bool add_rms_fusion;
     uint32_t partials_binding_alignment;
@@ -539,6 +540,9 @@ struct vk_device_struct {
     uint32_t subgroup_max_size;
     bool subgroup_require_full_support;
 
+    // floor(log2(maxComputeWorkGroupInvocations))
+    uint32_t max_workgroup_size_log2 {};
+
     bool coopmat_support;
     bool coopmat_acc_f32_support {};
     bool coopmat_acc_f16_support {};
@@ -684,6 +688,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
+    vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
@@ -1174,8 +1179,14 @@ struct vk_op_soft_max_push_constants {
 
 struct vk_op_argsort_push_constants {
     uint32_t ncols;
+    uint32_t ncols_padded;
+    uint32_t ncols_padded_log2;
     uint32_t nrows;
-    int32_t order;
+    uint32_t order;
+    uint32_t outer_start;
+    uint32_t outer_end;
+    uint32_t inner_start;
+    uint32_t inner_end;
 };
 
 struct vk_op_im2col_push_constants {
@@ -3895,7 +3906,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
     }
 
     for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
-        ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
+        uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
+        if (i <= device->max_workgroup_size_log2 &&
+            2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
+            const uint32_t NCOLS_PADDED_LOG2 = i;
+            ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
+        }
+        const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1;
+        BLOCK_SIZE /= WG_UNROLL_FACTOR;
+        ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -4296,6 +4315,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
 
+        device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
+
         std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
 
         // Try to find a non-graphics compute queue and transfer-focused queues
@@ -4435,6 +4456,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->shader_int64 = device_features2.features.shaderInt64;
         device->buffer_device_address = vk12_features.bufferDeviceAddress;
+        device->vulkan_memory_model = vk12_features.vulkanMemoryModel;
 
         if (device->subgroup_size_control) {
             device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@@ -8359,19 +8381,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             }
             return nullptr;
         }
-    case GGML_OP_ARGSORT:
-        if (ctx->num_additional_fused_ops) {
-            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
-            GGML_ASSERT(idx < num_topk_moe_pipelines);
-            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
-            return ctx->device->pipeline_topk_moe[idx][mode];
-        }
-
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
-            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
-            return ctx->device->pipeline_argsort_f32[idx];
-        }
-        return nullptr;
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_MEAN:
@@ -8763,8 +8772,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
         break;
     case GGML_OP_ARGSORT:
-        elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
-        elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+        GGML_ASSERT(0);
         break;
     case GGML_OP_IM2COL:
         {
@@ -9891,16 +9899,89 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
 }
 
 static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    int32_t * op_params = (int32_t *)dst->op_params;
+    const uint32_t * op_params = (const uint32_t *)dst->op_params;
 
     uint32_t ncols = src0->ne[0];
     uint32_t nrows = ggml_nrows(src0);
 
-    ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
-        ncols,
-        nrows,
-        op_params[0],
-    });
+    uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
+    uint32_t ncolsp2 = 1 << ncols_pad_log2;
+
+    vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, };
+
+    // Pick the largest workgroup size <= ncolsp2
+    uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1);
+
+    // Use the "small" argsort shader if the whole sort can be done by a single workgroup.
+    bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&
+                     ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;
+
+    vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx]
+                                     : ctx->device->pipeline_argsort_large_f32[pipeline_idx];
+
+    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer subbuf1 = dst_buf;
+
+    // Reserve space for ivec2 per element, with rows padded to a power of two
+    if (!use_small) {
+        const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);
+
+        if (ctx->prealloc_size_x < x_sz) {
+            ctx->prealloc_size_x = x_sz;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_x_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+        subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
+    }
+
+    std::array<uint32_t, 3> elements;
+
+    elements[0] = ncolsp2;
+    elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+    elements[2] = 1;
+
+    // First dispatch initializes tmp_idx and does the first N passes where
+    // there is only communication between threads in the same workgroup.
+    {
+        vk_op_argsort_push_constants pc2 = pc;
+        pc2.outer_start = 0;
+        pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
+        pc2.inner_start = 0;
+        pc2.inner_end = 100;
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
+    }
+    if (!use_small) {
+        ggml_vk_sync_buffers(ctx, subctx);
+        // Loop over outer/inner passes, synchronizing between each pass.
+        for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
+            for (uint32_t inner = 0; inner < outer + 1; ++inner) {
+                vk_op_argsort_push_constants pc2 = pc;
+                pc2.outer_start = outer;
+                pc2.outer_end = outer + 1;
+                pc2.inner_start = inner;
+                pc2.inner_end = inner + 1;
+                // When the inner idx is large enough, there's only communication
+                // within a workgroup. So the remaining inner iterations can all
+                // run in the same dispatch.
+                if (outer - inner < pipeline_idx) {
+                    pc2.inner_end = 100;
+                    inner = outer;
+                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx];
+                } else {
+                    // Smaller workgroup empirically seems to perform better
+                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2];
+                }
+                ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+                ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
+                ggml_vk_sync_buffers(ctx, subctx);
+            }
+        }
+        ctx->prealloc_x_need_sync = true;
+    }
 }
 
 static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -13721,7 +13802,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_LOG:
             return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_ARGSORT:
-            return op->ne[0] <= max_argsort_cols;
+            {
+                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
+                    return false;
+                }
+                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+                auto device = ggml_vk_get_device(ctx->device);
+                // pipeline_argsort_large_f32 requires vulkan memory model.
+                if (device->vulkan_memory_model) {
+                    return true;
+                } else {
+                    return op->ne[0] <= (1 << device->max_workgroup_size_log2);
+                }
+            }
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
         case GGML_OP_CONCAT:
index c4e68bc02370ac862a69aed68b277a7c60ab3126..0fc2b9b725350623763717450ae22ce31ff1840f 100644 (file)
@@ -4,28 +4,27 @@
 #include "types.glsl"
 
 layout(constant_id = 0) const int BLOCK_SIZE = 1024;
-layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
+layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
 #define ASC 0
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1)          buffer D {int data_d[];};
+layout (binding = 2) writeonly buffer D {int data_d[];};
 
 layout (push_constant) uniform parameter {
     uint ncols;
+    uint ncols_padded;
+    uint ncols_padded_log2;
     uint nrows;
     uint order;
+    uint outer_start;
+    uint outer_end;
+    uint inner_start;
+    uint inner_end;
 } p;
 
-shared int dst_row[BLOCK_SIZE];
-shared A_TYPE a_sh[BLOCK_SIZE];
-
-void swap(uint idx0, uint idx1) {
-    int tmp = dst_row[idx0];
-    dst_row[idx0] = dst_row[idx1];
-    dst_row[idx1] = tmp;
-}
+shared ivec2 dst_row[BLOCK_SIZE];
 
 void argsort(bool needs_bounds_check, const uint row) {
     // bitonic sort
@@ -34,11 +33,10 @@ void argsort(bool needs_bounds_check, const uint row) {
     const uint row_offset = row * p.ncols;
 
     // initialize indices
-    dst_row[col] = col;
-    a_sh[col] = data_a[row_offset + col];
+    dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col]));
     barrier();
 
-    uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
+    uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
     [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
         uint num_inner_loop_iters = outer_idx + 1;
         [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
@@ -47,14 +45,15 @@ void argsort(bool needs_bounds_check, const uint row) {
             int idx_0 = (col & k) == 0 ? col : ixj;
             int idx_1 = (col & k) == 0 ? ixj : col;
 
-            int sh_idx_0 = dst_row[idx_0];
-            int sh_idx_1 = dst_row[idx_1];
-            bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
-            bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
+            ivec2 sh_idx_0 = dst_row[idx_0];
+            ivec2 sh_idx_1 = dst_row[idx_1];
+            bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
+            bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;
 
             if ((idx_0_oob ||
-                (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
-                swap(idx_0, idx_1);
+                (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
+                dst_row[idx_0] = sh_idx_1;
+                dst_row[idx_1] = sh_idx_0;
             }
 
             barrier();
@@ -63,9 +62,9 @@ void argsort(bool needs_bounds_check, const uint row) {
 
     if (col < p.ncols) {
         if (p.order == ASC) {
-            data_d[row_offset + col] = dst_row[col];
+            data_d[row_offset + col] = dst_row[col].x;
         } else {
-            data_d[row_offset + p.ncols - col - 1] = dst_row[col];
+            data_d[row_offset + p.ncols - col - 1] = dst_row[col].x;
         }
     }
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp
new file mode 100644 (file)
index 0000000..920bac6
--- /dev/null
@@ -0,0 +1,114 @@
+#version 450
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_memory_scope_semantics : enable
+#pragma use_vulkan_memory_model
+
+#include "types.glsl"
+
+layout(constant_id = 0) const int BLOCK_SIZE = 1024;
+layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2;
+#define ASC 0
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];};
+layout (binding = 2) workgroupcoherent buffer D {int data_d[];};
+
+layout (push_constant) uniform parameter {
+    uint ncols;
+    uint ncols_padded;
+    uint ncols_padded_log2;
+    uint nrows;
+    uint order;
+    uint outer_start;
+    uint outer_end;
+    uint inner_start;
+    uint inner_end;
+} p;
+
+void argsort(bool needs_bounds_check, const uint row) {
+    // bitonic sort
+    int col = int(gl_GlobalInvocationID.x);
+    col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR;
+
+    const uint row_offset = row * p.ncols;
+    uint idx_offset = row * p.ncols_padded;
+
+    bool need_barrier = false;
+
+    // initialize indices
+    if (p.outer_start == 0 && p.inner_start == 0) {
+        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
+            uint c = u*BLOCK_SIZE + col;
+            if (c < p.ncols_padded) {
+                ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c]));
+                tmp_idx[idx_offset + c] = v;
+            }
+        }
+        need_barrier = true;
+    }
+
+    [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) {
+        uint inner_end = min(p.inner_end, outer_idx + 1);
+        for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) {
+            if (need_barrier) {
+                controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
+            }
+            need_barrier = true;
+            [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
+                int c = u*BLOCK_SIZE + col;
+                const int ixj = int(c ^ j);
+
+                if (ixj < c) {
+                    continue;
+                }
+
+                int idx_0 = (c & k) == 0 ? c : ixj;
+                int idx_1 = (c & k) == 0 ? ixj : c;
+
+                ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0];
+                ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1];
+                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
+                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;
+
+                if ((idx_0_oob ||
+                    (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) {
+                    tmp_idx[idx_offset + idx_0] = sh_idx_1;
+                    tmp_idx[idx_offset + idx_1] = sh_idx_0;
+                }
+            }
+        }
+    }
+
+    if (p.outer_end == p.ncols_padded_log2 &&
+        p.inner_end >= p.ncols_padded_log2 + 1) {
+        controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
+        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
+            uint c = u*BLOCK_SIZE + col;
+            if (c < p.ncols) {
+                if (p.order == ASC) {
+                    data_d[row_offset + c] = tmp_idx[idx_offset + c].x;
+                } else {
+                    data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x;
+                }
+            }
+        }
+    }
+}
+
+void main() {
+    if (p.ncols == p.ncols_padded) {
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(false, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
+    } else {
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(true, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
+    }
+}
index 71d91889c7a978bd68ce67b4d5f6da2172d14ff9..a6b8b9505261b34eb79265937ec951a60b9e245c 100644 (file)
@@ -892,6 +892,7 @@ void process_shaders() {
     string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 
     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
+    string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
 
     string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));