]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Use spec constants for conv2d s/d/p and kernel W/H (#16978)
authorJeff Bolz <redacted>
Sat, 8 Nov 2025 19:24:29 +0000 (13:24 -0600)
committerGitHub <redacted>
Sat, 8 Nov 2025 19:24:29 +0000 (13:24 -0600)
* vulkan: Use spec constants for conv2d s/d/p and kernel W/H

Also add some additional unroll hints, which seems to help.

* lock around map lookup

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

index 9c2aeb57f0003f082dccd2fcbcc292ac3c7105d3..6da7bbd2f611df43e63cecbf97a459b715af757a 100644 (file)
@@ -351,6 +351,12 @@ enum vk_conv_shapes {
     CONV_SHAPE_COUNT,
 };
 
+uint32_t conv_shapes_wg_denoms[][3] = {
+    { 128, 128, 1 },
+    {  64,  32, 1 },
+    {  32, 256, 1 },
+};
+
 enum dmmv_wg_sizes {
     DMMV_WG_SIZE_SUBGROUP,
     DMMV_WG_SIZE_LARGE,
@@ -379,6 +385,18 @@ struct vk_fa_pipeline_state {
     }
 };
 
+struct vk_conv2d_pipeline_state {
+    vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH)
+        : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {}
+
+    uint32_t s0, s1, p0, p1, d0, d1, KW, KH;
+
+    bool operator<(const vk_conv2d_pipeline_state &b) const {
+        return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) <
+               std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH);
+    }
+};
+
 enum shader_reduction_mode {
     SHADER_REDUCTION_MODE_SHMEM,
     SHADER_REDUCTION_MODE_HYBRID,
@@ -675,10 +693,10 @@ struct vk_device_struct {
     vk_pipeline pipeline_ssm_conv_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
     vk_pipeline pipeline_opt_step_sgd_f32;
-    vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
-    vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
-    vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
-    vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
+    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
+    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
+    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
+    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
     vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
 
@@ -1258,17 +1276,13 @@ struct vk_op_conv2d_push_constants {
     uint32_t nb2;
     uint32_t nb3;
 
-    // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH
-    uint32_t KWmp;   uint32_t KWL;
-    uint32_t KWKHmp; uint32_t KWKHL;
+    // init_fastdiv_values constants for dividing by OW, OW*OH
     uint32_t OWmp;   uint32_t OWL;
     uint32_t OWOHmp; uint32_t OWOHL;
 };
 
 template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
-    // Compute magic values to divide by KW, KW*KH, OW, OW*OH
-    init_fastdiv_values(p.KW,       p.KWmp,    p.KWL);
-    init_fastdiv_values(p.KW*p.KH,  p.KWKHmp,  p.KWKHL);
+    // Compute magic values to divide by OW, OW*OH
     init_fastdiv_values(p.OW,       p.OWmp,    p.OWL);
     init_fastdiv_values(p.OW*p.OH,  p.OWOHmp,  p.OWOHL);
 }
@@ -1304,23 +1318,15 @@ struct vk_op_conv_transpose_2d_push_constants {
     uint32_t nb2;
     uint32_t nb3;
 
-    // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1
-    uint32_t KWmp;   uint32_t KWL;
-    uint32_t KWKHmp; uint32_t KWKHL;
+    // init_fastdiv_values constants for dividing by OW, OW*OH
     uint32_t OWmp;   uint32_t OWL;
     uint32_t OWOHmp; uint32_t OWOHL;
-    uint32_t s0mp; uint32_t s0L;
-    uint32_t s1mp; uint32_t s1L;
 };
 
 template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
-    // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1
-    init_fastdiv_values(p.KW,       p.KWmp,    p.KWL);
-    init_fastdiv_values(p.KW*p.KH,  p.KWKHmp,  p.KWKHL);
+    // Compute magic values to divide by OW, OW*OH
     init_fastdiv_values(p.OW,       p.OWmp,    p.OWL);
     init_fastdiv_values(p.OW*p.OH,  p.OWOHmp,  p.OWOHL);
-    init_fastdiv_values(p.s0,       p.s0mp,    p.s0L);
-    init_fastdiv_values(p.s1,       p.s1mp,    p.s1L);
 }
 
 struct vk_op_conv2d_dw_push_constants {
@@ -3858,22 +3864,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
         switch (s) {
         default:
         case CONV_SHAPE_128x128:
-            conv2d_BS_K = 128;
-            conv2d_BS_NPQ = 128;
+            conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_128x128][0];
+            conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_128x128][1];
             conv2d_BS_CRS = 16;
             if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) {
                 conv2d_UNROLL = false;
             }
             break;
         case CONV_SHAPE_64x32:
-            conv2d_BS_K = 64;
-            conv2d_BS_NPQ = 32;
+            conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_64x32][0];
+            conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_64x32][1];
             conv2d_BS_CRS = 32;
             conv2d_TS_K   = 4;
             break;
         case CONV_SHAPE_32x256:
-            conv2d_BS_K = 32;
-            conv2d_BS_NPQ = 256;
+            conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_32x256][0];
+            conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_32x256][1];
             conv2d_BS_CRS = 16;
             break;
         }
@@ -3907,10 +3913,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
         std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
 
 #define CREATE_CONV(name, type_suffix, spv_suffix) \
-        ggml_vk_create_pipeline( \
-            device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \
-            name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
-            sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+        for (auto &c : device->pipeline_##name##type_suffix[s]) { \
+            const vk_conv2d_pipeline_state &state = c.first;  \
+            std::vector<uint32_t> spec_constants_cpy = spec_constants; \
+            spec_constants_cpy.push_back(state.s0); \
+            spec_constants_cpy.push_back(state.s1); \
+            spec_constants_cpy.push_back(state.p0); \
+            spec_constants_cpy.push_back(state.p1); \
+            spec_constants_cpy.push_back(state.d0); \
+            spec_constants_cpy.push_back(state.d1); \
+            spec_constants_cpy.push_back(state.KW); \
+            spec_constants_cpy.push_back(state.KH); \
+            ggml_vk_create_pipeline( \
+                device, c.second, #name #type_suffix, \
+                name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
+                sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives);    \
+        }
 #define CREATE_CONVS(spv_suffix) \
         CREATE_CONV(conv2d, _f32, spv_suffix) \
         CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
@@ -8536,7 +8554,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
 
             uint32_t tiles[CONV_SHAPE_COUNT];
             for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) {
-                tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]);
+                tiles[i] = CEIL_DIV(elements[0], conv_shapes_wg_denoms[i][0]) * CEIL_DIV(elements[1], conv_shapes_wg_denoms[i][1]);
             }
 
             // We can't query number of shader cores on Intel, use 32 as a placeholder
@@ -8551,19 +8569,45 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 shape = CONV_SHAPE_64x32;
             }
 
+            uint32_t KW = static_cast<uint32_t>(src0->ne[0]);
+            uint32_t KH = static_cast<uint32_t>(src0->ne[1]);
+            uint32_t s0 = static_cast<uint32_t>(dst->op_params[0]);
+            uint32_t s1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[1]) : static_cast<uint32_t>(dst->op_params[0]);
+            uint32_t p0 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[2]) : 0;
+            uint32_t p1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[3]) : 0;
+            uint32_t d0 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[4]) : 1;
+            uint32_t d1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[5]) : 1;
+
+            vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);
+
+            std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
             if (op == GGML_OP_CONV_2D) {
                 if (src0->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_conv2d_f32[shape];
+                    pipelines = &ctx->device->pipeline_conv2d_f32[shape];
                 } else if (src0->type == GGML_TYPE_F16) {
-                    return ctx->device->pipeline_conv2d_f16_f32[shape];
+                    pipelines = &ctx->device->pipeline_conv2d_f16_f32[shape];
                 }
             } else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
                 if (src0->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_conv_transpose_2d_f32[shape];
+                    pipelines = &ctx->device->pipeline_conv_transpose_2d_f32[shape];
                 } else if (src0->type == GGML_TYPE_F16) {
-                    return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
+                    pipelines = &ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
+                }
+            }
+
+            vk_pipeline pipeline = nullptr;
+
+            {
+                std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
+                auto it = pipelines->find(conv2d_pipeline_state);
+                if (it != pipelines->end()) {
+                    pipeline = it->second;
+                } else {
+                    (*pipelines)[conv2d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
                 }
             }
+
+            return pipeline;
         }
         return nullptr;
     case GGML_OP_CONV_2D_DW:
index 0367e80bbfa737783b903a3479f8ddb58eeeb21d..e9bdbf7db5e9abe0d866c6ab7ca77beed90d56b0 100644 (file)
@@ -62,14 +62,8 @@ layout(push_constant) uniform parameter {
     uint32_t nb3;
 
     // fastdiv helper values
-    uint32_t KWmp;   uint32_t KWL;
-    uint32_t KWKHmp; uint32_t KWKHL;
     uint32_t OWmp;   uint32_t OWL;
     uint32_t OWOHmp; uint32_t OWOHL;
-#ifdef TRANSPOSE
-    uint32_t s0mp; uint32_t s0L;
-    uint32_t s1mp; uint32_t s1L;
-#endif
 }
 
 p;
@@ -84,6 +78,15 @@ layout(constant_id = 4) const uint TS_K            = 8;
 layout(constant_id = 5) const uint use_collectives = 1;
 layout(constant_id = 6) const uint SHMEM_PAD       = 4;
 
+layout(constant_id = 7)  const uint s0             = 1;
+layout(constant_id = 8)  const uint s1             = 1;
+layout(constant_id = 9)  const uint p0             = 0;
+layout(constant_id = 10) const uint p1             = 0;
+layout(constant_id = 11) const uint d0             = 1;
+layout(constant_id = 12) const uint d1             = 1;
+layout(constant_id = 13) const uint KW             = 1;
+layout(constant_id = 14) const uint KH             = 1;
+
 uint32_t       tid     = gl_LocalInvocationID.x;
 const uint32_t WG_SIZE = gl_WorkGroupSize.x;
 
@@ -92,7 +95,7 @@ uint splitWork(uint work_size, uint block_size) {
 }
 
 uint32_t K   = p.Cout;
-uint32_t CRS = p.Cin * p.KH * p.KW;
+uint32_t CRS = p.Cin * KH * KW;
 uint32_t NPQ = p.N * p.OH * p.OW;
 
 uint32_t n_elems_out = K * NPQ;
@@ -187,7 +190,7 @@ void main() {
     }
 #endif
     /* Advance block in CRS dim */
-    for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
+    [[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
         uint32_t CRS_idx_a;
         uint32_t Cin_idx_a;
         uint32_t KH_idx_a;
@@ -200,10 +203,10 @@ void main() {
         uint32_t cached_KW_idx;
         if (use_collectives == 1) {
             cached_CRS_idx                = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
-            cached_Cin_idx                = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
-            uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
-            cached_KH_idx                 = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
-            cached_KW_idx                 = cached_CRS_remainder - cached_KH_idx * p.KW;
+            cached_Cin_idx                = cached_CRS_idx / (KW * KH);
+            uint32_t cached_CRS_remainder = cached_CRS_idx % (KW * KH);
+            cached_KH_idx                 = cached_CRS_remainder / KW;
+            cached_KW_idx                 = cached_CRS_remainder KW;
 
             CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
             Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
@@ -211,21 +214,21 @@ void main() {
             KW_idx_a  = subgroupShuffle(cached_KW_idx, Ac);
         } else {
             CRS_idx_a              = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
-            Cin_idx_a              = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
-            uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
-            KH_idx_a               = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
-            KW_idx_a               = CRS_remainder - KH_idx_a * p.KW;
+            Cin_idx_a              = CRS_idx_a / (KW * KH);
+            uint32_t CRS_remainder = CRS_idx_a % (KW * KH);
+            KH_idx_a               = CRS_remainder / KW;
+            KW_idx_a               = CRS_remainder KW;
         }
 #else
         CRS_idx_a     = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
-        Cin_idx_a     = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
-        CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
-        KH_idx_a      = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
-        KW_idx_a      = CRS_remainder - KH_idx_a * p.KW;
+        Cin_idx_a     = CRS_idx_a / (KW * KH);
+        CRS_remainder = CRS_idx_a % (KW * KH);
+        KH_idx_a      = CRS_remainder / KW;
+        KW_idx_a      = CRS_remainder KW;
 #endif
 
         /* Load kernel to A_block: (BS_K x BS_CRS)*/
-        for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
+        UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
             uint32_t B_ly    = r_offset + Ar;
             uint32_t B_lx    = Ac;
             uint32_t K_idx   = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
@@ -262,27 +265,27 @@ void main() {
                 KW_idx_b  = subgroupShuffle(cached_KW_idx, r_offset + Br);
             } else {
                 CRS_idx_b              = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
-                Cin_idx_b              = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
-                uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
-                KH_idx_b               = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
-                KW_idx_b               = CRS_remainder - KH_idx_b * p.KW;
+                Cin_idx_b              = CRS_idx_b / (KW * KH);
+                uint32_t CRS_remainder = CRS_idx_b % (KW * KH);
+                KH_idx_b               = CRS_remainder / KW;
+                KW_idx_b               = CRS_remainder KW;
             }
 #else
             CRS_idx_b              = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
-            Cin_idx_b              = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
-            uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
-            KH_idx_b               = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
-            KW_idx_b               = CRS_remainder - KH_idx_b * p.KW;
+            Cin_idx_b              = CRS_idx_b / (KW * KH);
+            uint32_t CRS_remainder = CRS_idx_b % (KW * KH);
+            KH_idx_b               = CRS_remainder / KW;
+            KW_idx_b               = CRS_remainder KW;
 #endif
 
 #ifdef TRANSPOSE
-            uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
-            uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
-            uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L);
-            uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L);
+            uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * d1 + p1;
+            uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * d0 + p0;
+            uint32_t H_idx = H_idx_x_s1 / s1;
+            uint32_t W_idx = W_idx_x_s0 / s0;
 #else
-            uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
-            uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
+            uint32_t H_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
+            uint32_t W_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
 #endif
             uint32_t src_idx =
                 min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
@@ -290,7 +293,7 @@ void main() {
             if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
                 || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
 #ifdef TRANSPOSE
-                || (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
+                || (H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0)
 #endif
                 ) {
                 val = 0.0;