]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: handle quantize_q8_1 overflowing the max workgroup count (llama/18515)
authorJeff Bolz <redacted>
Mon, 5 Jan 2026 10:30:14 +0000 (04:30 -0600)
committerGeorgi Gerganov <redacted>
Wed, 14 Jan 2026 07:11:59 +0000 (09:11 +0200)
* vulkan: handle quantize_q8_1 overflowing the max workgroup count

* vulkan: Fix small tile size matmul on lavapipe

* fix mul_mat_id failures

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp

index 16254457bb13cbc99b0de3a87e60e3cfdf1829ef..502a4deebc950f1e094d86bac0f0de5719948c3f 100644 (file)
@@ -2898,39 +2898,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
         const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
         const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
 
-        l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
-        m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
-        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
+        const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
 
-        l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
-        m_warptile_mmq = { 128,  64,  64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
-        s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
+        l_warptile = { 128,             128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
+        m_warptile = { 128,              64,  64, 16, subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+        s_warptile = { subgroup_size_32, 32,  32, 16, s_warptile_wm,       32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
+
+        l_warptile_mmq = { 128,             128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
+        m_warptile_mmq = { 128,              64,  64, 32, subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+        s_warptile_mmq = { subgroup_size_32, 32,  32, 32, s_warptile_wm,       32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
 
         // Integer MMQ has a smaller shared memory profile, but heavier register use
-        l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
-        m_warptile_mmq_int = { 128,  64,  64, 32, subgroup_size_8,     32, 2, 2, 2, 1, subgroup_size_8 };
-        s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32,       32, 2, 2, 1, 1, subgroup_size_8 };
+        l_warptile_mmq_int = { 128,             128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+        m_warptile_mmq_int = { 128,              64,  64, 32, subgroup_size_8,     32, 2, 2, 2, 1, subgroup_size_8 };
+        s_warptile_mmq_int = { subgroup_size_32, 32,  32, 32, s_warptile_wm,       32, 2, 2, 1, 1, subgroup_size_8 };
 
         // K-quants use even more registers, mitigate by setting WMITER to 1
-        l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
-        m_warptile_mmq_int_k = { 128,  64,  64, 32, subgroup_size_8,     32, 1, 2, 2, 1, subgroup_size_8 };
-        s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32,       32, 1, 2, 1, 1, subgroup_size_8 };
+        l_warptile_mmq_int_k = { 128,               128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
+        m_warptile_mmq_int_k = { 128,                64,  64, 32, subgroup_size_8,     32, 1, 2, 2, 1, subgroup_size_8 };
+        s_warptile_mmq_int_k = { subgroup_size_32,   32,  32, 32, s_warptile_wm,       32, 1, 2, 1, 1, subgroup_size_8 };
 
-        l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
-        m_warptile_id = { 128,  64,  64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
-        s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
+        l_warptile_id = { 128,                      128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
+        m_warptile_id = { 128,                       64,  64, 16, mul_mat_subgroup_size_16,     32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
+        s_warptile_id = { mul_mat_subgroup_size_16,  32,  32, 16, s_warptile_wm,                32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
 
-        l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
-        m_warptile_mmqid = { 128,  64,  64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
-        s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
+        l_warptile_mmqid = { 128,                       128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
+        m_warptile_mmqid = { 128,                        64,  64, 32, mul_mat_subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
+        s_warptile_mmqid = { mul_mat_subgroup_size_32,   32,  32, 32, s_warptile_wm,               32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
 
-        l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
-        m_warptile_mmqid_int = { 128,  64,  64, 32, mul_mat_subgroup_size_8,     32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
-        s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32,       32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
+        l_warptile_mmqid_int = { 128,                       128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
+        m_warptile_mmqid_int = { 128,                        64,  64, 32, mul_mat_subgroup_size_8,     32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
+        s_warptile_mmqid_int = { mul_mat_subgroup_size_32,   32,  32, 32, s_warptile_wm,               32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
 
-        l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
-        m_warptile_mmqid_int_k = { 128,  64,  64, 32, mul_mat_subgroup_size_16,     32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
-        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32,       32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
+        l_warptile_mmqid_int_k = { 128,                     128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
+        m_warptile_mmqid_int_k = { 128,                      64,  64, 32, mul_mat_subgroup_size_16,     32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
+        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32,  32, 32, s_warptile_wm,                32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
 
         // chip specific tuning
         if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
@@ -6773,7 +6775,12 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
 
     vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
+    const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);
+    // clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks.
+    const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());
+    const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 2>{ ne, num_blocks }, { elements, 1, 1 });
     ggml_vk_sync_buffers(ctx, subctx);
 }
 
index 20e45d0253e3680d1628aa1c8154839460ead4c6..7ea29a07e3748efc74ed7eea55bb10cb4e397c3f 100644 (file)
@@ -15,6 +15,7 @@
 layout (push_constant) uniform parameter
 {
     uint ne;
+    uint num_blocks;
 } p;
 
 #include "types.glsl"
@@ -33,8 +34,7 @@ layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
 shared float shmem[GROUP_SIZE];
 #endif
 
-void quantize() {
-    const uint wgid = gl_WorkGroupID.x;
+void quantize(const uint wgid) {
     const uint tid = INVOCATION_ID;
 
     // Each thread handles a vec4, so 8 threads handle a block
@@ -45,11 +45,7 @@ void quantize() {
     const uint ib = wgid * blocks_per_group + block_in_wg;
     const uint iqs = tid % 8;
 
-#ifndef QBLOCK_X4
-    if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
-        return;
-    }
-#else
+#ifdef QBLOCK_X4
     const uint ibx4_outer = ib / 4;
     const uint ibx4_inner = ib % 4;
 
@@ -123,5 +119,9 @@ void quantize() {
 }
 
 void main() {
-    quantize();
+    uint wgid = gl_WorkGroupID.x;
+    while (wgid < p.num_blocks) {
+        quantize(wgid);
+        wgid += gl_NumWorkGroups.x;
+    }
 }