]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Use VK_EXT_shader_64bit_indexing to handle large mat_mul(_id) (#18678)
authorJeff Bolz <redacted>
Mon, 12 Jan 2026 11:32:13 +0000 (05:32 -0600)
committerGitHub <redacted>
Mon, 12 Jan 2026 11:32:13 +0000 (12:32 +0100)
This fixes incoherent output in Llama-4-Maverick-17B-128E-PAB-Q8_0, which
has a mul_mat_id with an A matrix that's Q8_0 8192 x 5120 x 128.

This should work when the number of blocks in the A matrix is less than 2^32
(for mul_mat_vec or mul_mm_cm2), or for mul_mm I think the limit is like
2^32*LOAD_VEC_A elements.

- Divide batch_stride by QUANT_K earlier, so the block index calculation works in 32b.
- Each vk_pipeline_struct has a linked list of pipelines that will allow it to handle
variants. So far this change just adds a single use case for this, compiling with the
e64BitIndexingEXT flag.
- Use the 64b indexing variant when the A matrix is larger than maxStorageBufferRange.

64-bit indexing has some cost - around 3-5% in MoE models, so it's worth the effort
to avoid enabling it unconditionally.

20 files changed:
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
tests/test-backend-ops.cpp

index ba5252b814bfb0ede31bdfa272230de632269216..4b337cb931d51ad8d61a9da8e4e31bc02554540f 100644 (file)
@@ -119,6 +119,8 @@ struct ggml_backend_vk_context;
 // Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
 #define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
 
+typedef std::shared_ptr<struct vk_pipeline_struct> vk_pipeline;
+
 struct vk_pipeline_struct {
     std::string name;
     vk::ShaderModule shader_module;
@@ -136,9 +138,15 @@ struct vk_pipeline_struct {
     std::atomic<bool> compiled {};
     // number of registers used, extracted from pipeline executable properties
     uint32_t register_count {};
+
+#if defined(VK_EXT_shader_64bit_indexing)
+    bool is_64b_indexing {};
+#endif
+    // linked list of pipelines for multiple compilation variants.
+    // currently only used to compile a 64-bit indexing variant.
+    vk_pipeline next;
 };
 
-typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
 typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
 
 static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
@@ -584,6 +592,8 @@ struct vk_device_struct {
     bool add_rms_fusion;
     uint32_t partials_binding_alignment;
 
+    bool shader_64b_indexing;
+
     bool integer_dot_product;
     // 0: default, 1: force mmvq, -1: disable mmvq
     int32_t mmvq_mode;
@@ -2080,6 +2090,19 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         compute_pipeline_create_info.setPNext(&rci);
     }
 
+#if defined(VK_EXT_shader_64bit_indexing)
+    vk::PipelineCreateFlags2CreateInfo pipelineFlags2CreateInfo;
+    if (pipeline->is_64b_indexing)
+    {
+        pipelineFlags2CreateInfo.flags = vk::PipelineCreateFlagBits2::e64BitIndexingEXT;
+        if (device->pipeline_executable_properties_support) {
+            pipelineFlags2CreateInfo.flags |= vk::PipelineCreateFlagBits2::eCaptureStatisticsKHR;
+        }
+        pipelineFlags2CreateInfo.setPNext(compute_pipeline_create_info.pNext);
+        compute_pipeline_create_info.setPNext(&pipelineFlags2CreateInfo);
+    }
+#endif
+
     try {
         pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
     } catch (const vk::SystemError& e) {
@@ -3066,7 +3089,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     }
 
     std::vector<std::future<void>> compiles;
-    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
+    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
                                               uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
                                               uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
 
@@ -3074,35 +3097,49 @@ static void ggml_vk_load_shaders(vk_device& device) {
             required_subgroup_size = get_subgroup_size(name, device->architecture);
         }
 
-        if (!pipeline) {
-            pipeline = std::make_shared<vk_pipeline_struct>();
-        }
-        if (!pipeline->initialized) {
-            pipeline->name = name;
-            pipeline->parameter_count = parameter_count;
-            pipeline->push_constant_size = push_constant_size;
-            pipeline->wg_denoms = wg_denoms;
-            pipeline->align = align;
-            pipeline->initialized = true;
-        }
+        vk_pipeline *ptr = &base_pipeline;
 
-        if (!pipeline->needed || pipeline->compiled) {
-            return;
+        int num_pipelines = 1;
+#if defined(VK_EXT_shader_64bit_indexing)
+        if (device->shader_64b_indexing) {
+            num_pipelines = 2;
         }
-        // TODO: We're no longer benefitting from the async compiles (shaders are
-        // compiled individually, as needed) and this complexity can be removed.
-        {
-            // wait until fewer than N compiles are in progress
-            uint32_t N = std::max(1u, std::thread::hardware_concurrency());
-            std::unique_lock<std::mutex> guard(compile_count_mutex);
-            while (compile_count >= N) {
-                compile_count_cond.wait(guard);
+#endif
+        for (int i = 0; i < num_pipelines; ++i, ptr = &(*ptr)->next) {
+            vk_pipeline &pipeline = *ptr;
+            if (!pipeline) {
+                pipeline = std::make_shared<vk_pipeline_struct>();
+            }
+            if (!pipeline->initialized) {
+                pipeline->name = name;
+                pipeline->parameter_count = parameter_count;
+                pipeline->push_constant_size = push_constant_size;
+                pipeline->wg_denoms = wg_denoms;
+                pipeline->align = align;
+                pipeline->initialized = true;
+#if defined(VK_EXT_shader_64bit_indexing)
+                pipeline->is_64b_indexing = (i == 1);
+#endif
+            }
+
+            if (!pipeline->needed || pipeline->compiled) {
+                continue;
+            }
+            // TODO: We're no longer benefitting from the async compiles (shaders are
+            // compiled individually, as needed) and this complexity can be removed.
+            {
+                // wait until fewer than N compiles are in progress
+                uint32_t N = std::max(1u, std::thread::hardware_concurrency());
+                std::unique_lock<std::mutex> guard(compile_count_mutex);
+                while (compile_count >= N) {
+                    compile_count_cond.wait(guard);
+                }
+                compile_count++;
             }
-            compile_count++;
-        }
 
-        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
-                                      parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
+            compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
+                                          parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
+        }
     };
 
     auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint,
@@ -4480,6 +4517,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         bool pipeline_executable_properties_support = false;
         device->coopmat_support = false;
         device->integer_dot_product = false;
+        device->shader_64b_indexing = false;
         bool bfloat16_support = false;
 
         for (const auto& properties : ext_props) {
@@ -4527,6 +4565,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
                 device->memory_priority = true;
             } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
                 device->external_memory_host = true;
+#if defined(VK_EXT_shader_64bit_indexing)
+            } else if (strcmp("VK_EXT_shader_64bit_indexing", properties.extensionName) == 0) {
+                device->shader_64b_indexing = true;
+#endif
             }
         }
 
@@ -4817,6 +4859,16 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_EXT_external_memory_host");
         }
 
+#if defined(VK_EXT_shader_64bit_indexing)
+        VkPhysicalDeviceShader64BitIndexingFeaturesEXT shader_64bit_indexing_features {};
+        shader_64bit_indexing_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_64_BIT_INDEXING_FEATURES_EXT;
+        if (device->shader_64b_indexing) {
+            last_struct->pNext = (VkBaseOutStructure *)&shader_64bit_indexing_features;
+            last_struct = (VkBaseOutStructure *)&shader_64bit_indexing_features;
+            device_extensions.push_back("VK_EXT_shader_64bit_indexing");
+        }
+#endif
+
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
         device->pipeline_executable_properties_support = pipeline_executable_properties_support;
@@ -6902,6 +6954,20 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
     ggml_vk_sync_buffers(ctx, subctx);
 }
 
+static vk_pipeline ggml_vk_get_64b_indexing_pipeline(ggml_backend_vk_context * ctx, vk_pipeline &pipeline) {
+    GGML_UNUSED(ctx);
+#if defined(VK_EXT_shader_64bit_indexing)
+    vk_pipeline *ptr = &pipeline;
+    while (*ptr) {
+        if ((*ptr)->is_64b_indexing) {
+            return *ptr;
+        }
+        ptr = &(*ptr)->next;
+    }
+#endif
+    return pipeline;
+}
+
 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -6985,6 +7051,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
 
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
+
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
     const uint64_t x_ne = ggml_nelements(src0);
@@ -7294,6 +7364,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
     }
 
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
+    }
+
     const bool qx_needs_dequant = x_non_contig;
     const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
 
@@ -7489,9 +7563,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
         gqa_ratio = 1;
     }
 
+    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1];
+
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
+
     {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
     }
 
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
@@ -7533,7 +7613,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
         workgroups_z /= gqa_ratio;
     }
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1],
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
             d_Qx,
             d_Qy,
@@ -7583,9 +7663,14 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
     const uint32_t channel_stride_y = nb12 / sizeof(float);
 
+    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_nc_f16_f32;
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
+
     {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
     }
 
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
@@ -7622,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
 
     init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
             d_Qx,
             d_Qy,
@@ -7641,8 +7726,9 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
     // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases
     // where the M dimension is very large.
     // Split_k doesn't work with M splitting.
+    // This only supports batchsize == 1.
     const size_t nbytes = ggml_nbytes(src0);
-    const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange;
+    const bool needs_split = dst->ne[2] == 1 && dst->ne[3] == 1 && nbytes > ctx->device->properties.limits.maxStorageBufferRange;
     if (needs_split) {
         // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
         const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
@@ -7784,6 +7870,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
 
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
     const uint64_t x_ne = ggml_nelements(src0);
@@ -8045,6 +8134,10 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     const bool qx_needs_dequant = x_non_contig;
     const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
 
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
+    }
+
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
index b3c96576debafc5e529b453ecc25b69823013cc6..2271be4021bcfed3b982cb6283892bd88893cf55 100644 (file)
@@ -87,7 +87,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint tid = gl_LocalInvocationID.x;
 
     get_offsets(a_offset, b_offset, d_offset);
-    a_offset /= QUANT_K;
 
     y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
 
index cfc8b0c7f4b84f461d40f9abdac11b776c9146d2..dfb78659362b223603524aab4f3ea34fd19c6ec8 100644 (file)
@@ -65,9 +65,9 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
 
     a_offset =
 #ifdef MUL_MAT_ID
-            expert_id * p.batch_stride_a;
+            expert_id * (p.batch_stride_a / QUANT_K);
 #else
-            batch_idx_a * p.batch_stride_a;
+            batch_idx_a * (p.batch_stride_a / QUANT_K);
 #endif
     b_offset =
 #ifdef MUL_MAT_ID
index e5cc7ff8629a699febc084867c70e3bd98ed1fce..3ea24a76cec006fa80bdde5c0ff57925d671d068 100644 (file)
@@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
                                const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
     // Compute starting index in matrix B for this superblock
     const uint y_idx = i * QUANT_K + 32 * ib32;
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
 
     // Precompute indices for quantization lookup tables
     const uint qh_base = 2 * ib32;
index c5f5e9cbb2b613fe7f62276b0b061b4186a7c714..fd953c8faddcd57c80e6b548fc5a2303e36bb7f0 100644 (file)
@@ -17,7 +17,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
             const vec4 b_val_1 = vec4(data_b_v4[base_b_idx + 2 * l + 1]);
 
             // index for data_a
-            uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+            uint ibi = a_offset + first_row * num_blocks_per_row + i;
 
             [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                 const float d = float(data_a[ibi].d);
index e424af12c5a6fbb3af69ad6e0048440d15e8e5e1..b4f6d1d6b648dc95126d2f40a8396d92a6c9d264 100644 (file)
@@ -12,7 +12,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint nibble_shift = 4 * (itid & 1);
     const uint ib32 = itid / 2; // 0..7
 
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
index 7ec2e04f58e11bbcdb5d21a6ad44d2776d789f9f..d8dafe5f709d891ad76cba57ff388555bdbedeef 100644 (file)
@@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + 16 * itid;
     const uint nibble_shift = 4 * (itid & 1);
     const uint ib32 = itid / 2; // 0..7
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     // Precompute db multiplication factors
     float db_vals[NUM_ROWS];
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
@@ -22,7 +22,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
         db_vals[n] = d * (0.125f + float(scale) * 0.25f);
         ibi += num_blocks_per_row;
     }
-    ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         // Preload grid and sign data for all l values
         vec4 grid0_vals[2], grid1_vals[2];
index 71bd72d17e3893272ce78f0b1a0f9a396c548966..f75dcf8331d93b4e1c1b7a073f462441b2d9d5ee 100644 (file)
@@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + 16 * itid;
     const uint ib32 = itid / 2; // 0..7
 
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint signscale = pack32(u16vec2(
index a4b9ab1f94f1025dfcebcf70d8d16d187daba3ab..5cdf2a89d0fd2c6888ebdf68c64bde0a00cf018d 100644 (file)
@@ -10,7 +10,7 @@ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
     const uint y_idx = i * QUANT_K + 32 * ib32;
 
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
index 40849c691f297ea63a55adce3bb8b6997273bd99..a88898109ab6ee4da4da4b141db8e304e2e35ae9 100644 (file)
@@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + 16 * itid;
     const uint ib32 = itid / 2; // 0..7
 
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint signscale = pack32(u16vec2(
index 14093c0de5a4593ead242a14015d3400efa11725..619de054cb8be7849f3ae83c1078be84e40d3612 100644 (file)
@@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + y_offset;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         csel ^= 1;
 
         if (!all_threads) { // when we don't have enough blocks to use all threads
index 528f224d86bc6045314b88c2e1efb702a76f81c8..93e48b790122a922f9c0d7d76022b459d2cf2ccb 100644 (file)
@@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
     const uint y_idx = i * QUANT_K + y_offset;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         csel ^= 1;
 
         if (!all_threads) { // when we don't have enough blocks to use all threads
index 49d91ad59101ef54601cc0133f4e56b37247e343..6af5a81587d4dc5511fb4067d6daaabc6d468aa3 100644 (file)
@@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
     const uint y2_idx = y1_idx + 128;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
 
         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
index 0d61b4966ec4a424fff402c2cca852faa3419b4a..3695b47b98d727829cb70b84d5b99faec019522e 100644 (file)
@@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
     const uint y2_idx = y1_idx + 128;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
 
         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
index d7a7f6426ee957ec8162ced57efe924cf06aa1ed..3e89d91cbb0476f82bf04899a6d90863f9cc737b 100644 (file)
@@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + y_offset;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         csel ^= 1;
 
         if (!all_threads) { // when we don't have enough blocks to use all threads
index ff5f43979d2e0789002a59489399d2854b0e5dc9..6fe3e2dc0437beae23a0fe63e955a31366e6b3df 100644 (file)
@@ -79,7 +79,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint tid = gl_LocalInvocationID.x;
 
     get_offsets(a_offset, b_offset, d_offset);
-    a_offset /= QUANT_K_Q8_1;
+    a_offset *= QUANT_K / QUANT_K_Q8_1;
     b_offset /= QUANT_K_Q8_1;
 
     FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
index c0c00d28fca678f750c08e65c022cc9fe9a48742..775e9a70f6d52213edd2092044772a2747d7182b 100644 (file)
@@ -234,13 +234,13 @@ void main() {
     const uint end_k = min(p.K, (ik + 1) * p.k_split);
 #endif
 
-    uint pos_a = (
+    uint pos_a =
 #ifdef MUL_MAT_ID
-        expert_idx * p.batch_stride_a +
+        expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
 #else
-        batch_idx_a * p.batch_stride_a +
+        batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
 #endif
-        ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
+        (ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
 #ifdef MUL_MAT_ID
     uint pos_b = 0;
 #else
index d0d1d8ef7231b772e4d7f40f7af2669fcfa9fdc5..b6614d2fc5999d5cf90bdb4756c904cc5aa5193a 100644 (file)
@@ -250,10 +250,10 @@ void main() {
 #endif
 
 #ifdef MUL_MAT_ID
-    uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
+    uint pos_a = expert_idx * (p.batch_stride_a / QUANT_K);
     uint pos_b = 0;
 #else
-    uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
+    uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
     uint pos_b = batch_idx * p.batch_stride_b;
     uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif
index cd36e270ab3f94a6d50cd46fa44094d8656defde..335d7f6a68273095416e1ac9a6aa235ee57046ed 100644 (file)
@@ -189,13 +189,13 @@ void main() {
     const uint end_k = min(p.K, (ik + 1) * p.k_split);
 #endif
 
-    uint pos_a_ib = (
+    uint pos_a_ib =
 #ifdef MUL_MAT_ID
-        expert_idx * p.batch_stride_a +
+        expert_idx * (p.batch_stride_a / BK) +
 #else
-        batch_idx_a * p.batch_stride_a +
+        batch_idx_a * (p.batch_stride_a / BK) +
 #endif
-        ir * BM * p.stride_a + start_k) / BK;
+        (ir * BM * p.stride_a + start_k) / BK;
 #ifdef MUL_MAT_ID
     uint pos_b_ib = 0;
 #else
index 56d277e167037afeebc9f11de023419f57142680..19ef58404efd2dd68444c867987eb84b861d598e 100644 (file)
@@ -7560,6 +7560,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 96, 2592, {1, 1}, {1, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000,  3, 2592, {1, 1}, {1, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000,  1, 2592, {1, 1}, {1, 1}));
+
+    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 2, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0
+    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 1, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 1, 5120, {128, 1}, {1, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 512, 5120, {128, 1}, {1, 1}));
 #endif
 
     for (ggml_type type_a : all_types) {