]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: optimize mul_mat_id loading row ids into shared memory (llama/15427)
authorJeff Bolz <redacted>
Sat, 23 Aug 2025 06:31:54 +0000 (01:31 -0500)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:54:01 +0000 (12:54 +0300)
- Spread the work across the whole workgroup. Using more threads seems to
far outweigh the synchronization overhead.
- Specialize the code for when the division is by a power of two.

src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/mul_mm.comp
src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

index fb18a55cdad2c64c6309dc44e16cf47719746f7c..2c5678f4884cf21a1c927ad1a62c3f2f1b4d11cb 100644 (file)
@@ -2168,9 +2168,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         s_mmq_wg_denoms_k = { 32,  64,  1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 128, 16, 0 };
-        m_warptile_mmqid = { 256, 128, 64, 16, 0 };
-        s_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
+        m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
+        s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
         l_mmqid_wg_denoms = { 128, 128, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
         s_mmqid_wg_denoms = { 128, 64, 1 };
index a61a464c7bef88593ff7b711190349f5d265b7b1..d57cc6bdec5df0e290c37989eeae96b60dda99a3 100644 (file)
@@ -103,16 +103,74 @@ layout (constant_id = 10) const uint WARP = 32;
 shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
 shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
 
+#define NUM_WARPS (BLOCK_SIZE / WARP)
+
 #ifdef MUL_MAT_ID
 shared u16vec2 row_ids[4096];
 uint _ne1;
 #ifdef COOPMAT
-shared uint _ne1_sh;
+shared uvec4 ballots_sh[NUM_WARPS];
+void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
+    _ne1 = 0;
+    uint num_elements = p.nei1 * p.nei0;
+    uint nei0shift = findLSB(p.nei0);
+
+    uint ids[16];
+    uint iter = 0;
+
+    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
+        // prefetch up to 16 elements
+        if (iter == 0) {
+            [[unroll]] for (uint k = 0; k < 16; ++k) {
+                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
+                bool in_range = i < num_elements;
+                uint ii1;
+                if (nei0_is_pow2) {
+                    ii1 = i >> nei0shift;
+                } else {
+                    ii1 = i / p.nei0;
+                }
+                uint ii0 = i - ii1 * p.nei0;
+                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+            }
+        }
+        uint i = j + gl_LocalInvocationIndex;
+        bool in_range = i < num_elements;
+        uint ii1;
+        if (nei0_is_pow2) {
+            ii1 = i >> nei0shift;
+        } else {
+            ii1 = i / p.nei0;
+        }
+        uint ii0 = i - ii1 * p.nei0;
+        uint id = ids[iter++];
+        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+
+        ballots_sh[gl_SubgroupID] = ballot;
+        barrier();
+
+        uint subgroup_base = 0;
+        uint total = 0;
+        for (uint k = 0; k < gl_NumSubgroups; ++k) {
+            if (k == gl_SubgroupID) {
+                subgroup_base = total;
+            }
+            total += subgroupBallotBitCount(ballots_sh[k]);
+        }
+        barrier();
+
+        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
+        if (in_range && id == expert_idx) {
+            row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
+        }
+        _ne1 += total;
+        iter &= 15;
+    }
+    barrier();
+}
 #endif
 #endif // MUL_MAT_ID
 
-#define NUM_WARPS (BLOCK_SIZE / WARP)
-
 #ifdef COOPMAT
 shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
 #endif
@@ -178,44 +236,11 @@ void main() {
 
 #ifdef MUL_MAT_ID
 #ifdef COOPMAT
-    // Spread the search across all elements in the first subgroup
-    if (gl_SubgroupID == 0) {
-        _ne1 = 0;
-        uint num_elements = p.nei1 * p.nei0;
-
-        uint ids[16];
-        uint iter = 0;
-
-        for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
-            // prefetch up to 16 elements
-            if (iter == 0) {
-                [[unroll]] for (uint k = 0; k < 16; ++k) {
-                    uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
-                    bool in_range = i < num_elements;
-                    uint ii1 = i / p.nei0;
-                    uint ii0 = i % p.nei0;
-                    ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
-                }
-            }
-            uint i = j + gl_SubgroupInvocationID;
-            bool in_range = i < num_elements;
-            uint ii1 = i / p.nei0;
-            uint ii0 = i % p.nei0;
-            uint id = ids[iter++];
-            uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
-            uint idx = subgroupBallotExclusiveBitCount(ballot);
-            if (in_range && id == expert_idx) {
-                row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
-            }
-            _ne1 += subgroupBallotBitCount(ballot);
-            iter &= 15;
-        }
-        _ne1_sh = _ne1;
+    if (bitCount(p.nei0) == 1) {
+        load_row_ids(expert_idx, true);
+    } else {
+        load_row_ids(expert_idx, false);
     }
-
-    barrier();
-
-    _ne1 = _ne1_sh;
 #else
     _ne1 = 0;
     for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
index 29e4b5c9ce2d4141b66819a86ad9c96197b490b3..4d16eb0791ddc6ab206340599f7c57ddcb4fcb59 100644 (file)
@@ -19,6 +19,7 @@
 #endif
 
 #include "types.comp"
+#include "utils.comp"
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
@@ -99,7 +100,8 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
 };
 
 uint _ne1;
-shared uint _ne1_sh;
+layout (constant_id = 5) const uint subgroup_size = 32;
+shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
 
 B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
 {
@@ -128,6 +130,64 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
     return elem;
 }
 
+void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
+    _ne1 = 0;
+    uint num_elements = p.nei1 * p.nei0;
+    uint nei0shift = findLSB(p.nei0);
+
+    uint ids[16];
+    uint iter = 0;
+
+    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
+        // prefetch up to 16 elements
+        if (iter == 0) {
+            [[unroll]] for (uint k = 0; k < 16; ++k) {
+                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
+                bool in_range = i < num_elements;
+                uint ii1;
+                if (nei0_is_pow2) {
+                    ii1 = i >> nei0shift;
+                } else {
+                    ii1 = i / p.nei0;
+                }
+                uint ii0 = i - ii1 * p.nei0;
+                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+            }
+        }
+        uint i = j + gl_LocalInvocationIndex;
+        bool in_range = i < num_elements;
+        uint ii1;
+        if (nei0_is_pow2) {
+            ii1 = i >> nei0shift;
+        } else {
+            ii1 = i / p.nei0;
+        }
+        uint ii0 = i - ii1 * p.nei0;
+        uint id = ids[iter++];
+        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+
+        ballots_sh[gl_SubgroupID] = ballot;
+        barrier();
+
+        uint subgroup_base = 0;
+        uint total = 0;
+        for (uint k = 0; k < gl_NumSubgroups; ++k) {
+            if (k == gl_SubgroupID) {
+                subgroup_base = total;
+            }
+            total += subgroupBallotBitCount(ballots_sh[k]);
+        }
+        barrier();
+
+        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
+        if (in_range && id == expert_idx) {
+            row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
+        }
+        _ne1 += total;
+        iter &= 15;
+    }
+    barrier();
+}
 #endif
 
 void main() {
@@ -157,45 +217,12 @@ void main() {
     const uint ic = gl_WorkGroupID.y;
 
 #ifdef MUL_MAT_ID
-    // Spread the search across all elements in the first subgroup
-    if (gl_SubgroupID == 0) {
-        _ne1 = 0;
-        uint num_elements = p.nei1 * p.nei0;
-
-        uint ids[16];
-        uint iter = 0;
-
-        for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
-            // prefetch up to 16 elements
-            if (iter == 0) {
-                [[unroll]] for (uint k = 0; k < 16; ++k) {
-                    uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
-                    bool in_range = i < num_elements;
-                    uint ii1 = i / p.nei0;
-                    uint ii0 = i % p.nei0;
-                    ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
-                }
-            }
-            uint i = j + gl_SubgroupInvocationID;
-            bool in_range = i < num_elements;
-            uint ii1 = i / p.nei0;
-            uint ii0 = i % p.nei0;
-            uint id = ids[iter++];
-            uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
-            uint idx = subgroupBallotExclusiveBitCount(ballot);
-            if (in_range && id == expert_idx) {
-                row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
-            }
-            _ne1 += subgroupBallotBitCount(ballot);
-            iter &= 15;
-        }
-        _ne1_sh = _ne1;
+    if (bitCount(p.nei0) == 1) {
+        load_row_ids(expert_idx, true);
+    } else {
+        load_row_ids(expert_idx, false);
     }
 
-    barrier();
-
-    _ne1 = _ne1_sh;
-
     // Workgroup has no work
     if (ic * BN >= _ne1) return;
 #endif