]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: Remove splitting for mul_mat_id (llama/15568)
authorJeff Bolz <redacted>
Tue, 26 Aug 2025 04:42:44 +0000 (23:42 -0500)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:54:04 +0000 (12:54 +0300)
row_ids only needs to hold the BN rows for the current tile.

src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/mul_mm.comp
src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
tests/test-backend-ops.cpp

index 30e53175042acfa3a0ff361a8afb2bc1c95c7e20..04ad664e61c0731317919d4218477cf15f16411c 100644 (file)
@@ -2090,10 +2090,11 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     const uint32_t warps = warptile[0] / warptile[10];
 
     const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
-    const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
+    const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
     const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
+    const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
 
-    const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
+    const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
@@ -6288,7 +6289,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     const uint64_t nei0 = ids->ne[0];
     const uint64_t nei1 = ids->ne[1];
-    GGML_ASSERT(nei0 * nei1 <= 4096);
 
     const uint32_t nbi1 = ids->nb[1];
     const uint32_t nbi2 = ids->nb[2];
@@ -6728,37 +6728,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
     } else {
-        // Split based on number of ids, to fit in shared memory
-        const uint32_t nei0 = (uint32_t)src2->ne[0];
-        const uint32_t nei1 = (uint32_t)src2->ne[1];
-
-        GGML_ASSERT(nei0 <= 4096);
-        const uint32_t split_size = std::min(nei1, 4096u / nei0);
-
-        if (split_size == nei1) {
-            ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
-        } else {
-            ggml_tensor src1_copy = *src1;
-            ggml_tensor src2_copy = *src2;
-            ggml_tensor dst_copy = *dst;
-
-            for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
-                const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
-
-                src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
-                src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
-                dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
-
-                src1_copy.ne[2] = n_tokens;
-                src2_copy.ne[1] = n_tokens;
-                dst_copy.ne[2] = n_tokens;
-
-                ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
-                // invalidate cached prealloc_y, can't cache based on the copy of the ggml_tensor
-                ctx->prealloc_y_last_pipeline_used = {};
-                ctx->prealloc_y_last_tensor_used = nullptr;
-            }
-        }
+        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
     }
 }
 
index 40c0d9b0c5731f75ddd38cd57f0fb56008973ad7..5ecf68a64383b84a56daf936d9048382f2f287ad 100644 (file)
@@ -109,13 +109,13 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
 #define NUM_WARPS (BLOCK_SIZE / WARP)
 
 #ifdef MUL_MAT_ID
-shared u16vec2 row_ids[4096];
+shared u16vec2 row_ids[BN];
 uint _ne1;
 
 #ifdef MUL_MAT_ID_USE_SUBGROUPS
 shared uvec4 ballots_sh[NUM_WARPS];
 
-void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
     _ne1 = 0;
     uint num_elements = p.nei1 * p.nei0;
     uint nei0shift = findLSB(p.nei0);
@@ -165,11 +165,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
         barrier();
 
         uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
-        if (in_range && id == expert_idx) {
-            row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
+        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+            row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
         }
         _ne1 += total;
         iter &= 15;
+        if (_ne1 >= (ic + 1) * BN) {
+            break;
+        }
     }
     barrier();
 }
@@ -242,16 +245,18 @@ void main() {
 #ifdef MUL_MAT_ID
 #ifdef MUL_MAT_ID_USE_SUBGROUPS
     if (bitCount(p.nei0) == 1) {
-        load_row_ids(expert_idx, true);
+        load_row_ids(expert_idx, true, ic);
     } else {
-        load_row_ids(expert_idx, false);
+        load_row_ids(expert_idx, false, ic);
     }
 #else
     _ne1 = 0;
-    for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
-        for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
+        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
             if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
-                row_ids[_ne1] = u16vec2(ii0, ii1);
+                if (_ne1 >= ic * BN) {
+                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
+                }
                 _ne1++;
             }
         }
@@ -797,7 +802,7 @@ void main() {
         [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
 #if LOAD_VEC_B == 8
 #ifdef MUL_MAT_ID
-            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+            const u16vec2 row_idx = row_ids[loadc_b + l];
             const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
 #else
             const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -813,7 +818,7 @@ void main() {
             buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
 #elif LOAD_VEC_B == 4
 #ifdef MUL_MAT_ID
-            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+            const u16vec2 row_idx = row_ids[loadc_b + l];
             const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
 #else
             const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -832,7 +837,7 @@ void main() {
 #else
             const uint row_i = ic * BN + loadc_b + l;
             if (row_i < _ne1 && block + loadr_b < end_k) {
-                const u16vec2 row_idx = row_ids[row_i];
+                const u16vec2 row_idx = row_ids[loadc_b + l];
                 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
             } else {
                 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
@@ -903,7 +908,7 @@ void main() {
                 const uint row_i = dc + cm_col * TN + col + store_c;
                 if (row_i >= _ne1) break;
 
-                const u16vec2 row_idx = row_ids[row_i];
+                const u16vec2 row_idx = row_ids[row_i - ic * BN];
 
                 if (dr + cm_row * TM + store_r < p.M) {
                     data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
@@ -953,7 +958,7 @@ void main() {
                 const uint row_i = dc_warp + cc;
                 if (row_i >= _ne1) break;
 
-                const u16vec2 row_idx = row_ids[row_i];
+                const u16vec2 row_idx = row_ids[row_i - ic * BN];
 #endif // MUL_MAT_ID
                 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
 #ifdef MUL_MAT_ID
index 4d16eb0791ddc6ab206340599f7c57ddcb4fcb59..f5aebf6e93f948386000b0b2fdde42dae2821110 100644 (file)
@@ -93,7 +93,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
 
-shared u16vec4 row_ids[4096];
+shared u16vec4 row_ids[BN];
 
 layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
    B_TYPE b[];
@@ -111,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
         return B_TYPE(0.0);
     }
 
-    const u16vec4 row_idx = row_ids[row_i];
+    const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
     B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
 
     return ret;
@@ -123,14 +123,14 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
     uint dc = ic * BN + c;
 
     if (dr < p.M && dc < _ne1) {
-        uint row_i = dc;
+        uint row_i = c;
         const u16vec4 row_idx = row_ids[row_i];
         data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
     }
     return elem;
 }
 
-void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
     _ne1 = 0;
     uint num_elements = p.nei1 * p.nei0;
     uint nei0shift = findLSB(p.nei0);
@@ -180,11 +180,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
         barrier();
 
         uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
-        if (in_range && id == expert_idx) {
-            row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
+        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+            row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
         }
         _ne1 += total;
         iter &= 15;
+        if (_ne1 >= (ic + 1) * BN) {
+            break;
+        }
     }
     barrier();
 }
@@ -218,9 +221,9 @@ void main() {
 
 #ifdef MUL_MAT_ID
     if (bitCount(p.nei0) == 1) {
-        load_row_ids(expert_idx, true);
+        load_row_ids(expert_idx, true, ic);
     } else {
-        load_row_ids(expert_idx, false);
+        load_row_ids(expert_idx, false, ic);
     }
 
     // Workgroup has no work
index ef6f452195ba202d6304a3af44af51f28bc39cc0..765521ffebb4e4744a8a73e61880d2ed69aaeb76 100644 (file)
@@ -6017,6 +6017,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     // test large experts*tokens
     for (bool b : {false, true}) {
         test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
+        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64));
     }
 
     test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));