]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: optimizations for deepseek prompt processing (llama/14555)
authorJeff Bolz <redacted>
Sat, 12 Jul 2025 09:51:58 +0000 (04:51 -0500)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
* vulkan: allow unclamped loads in coopmat2 mul_mat_id shader

* vulkan: increase coopmat2 mul_mat_id tile size

* vulkan: optimize mat_mul_id row_ids search to batch loads, and port to coopmat1 path

* vulkan: use smaller FA row size when head size is large. applies to both scalar and CM2 paths (CM1 isn't used due to shared memory limits)

src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/mul_mm.comp
src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

index c36e1a6d3bfc26f91028d7b4b5fbbe1f01945a7f..cdddf868fbe29cd53e099b01ecd174446cfd52f6 100644 (file)
@@ -1735,7 +1735,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
-static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
+
+static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
+    if (hsv >= 512) {
+        return 2;
+    } else {
+        return 8;
+    }
+}
 
 // The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
 // 128 threads split into four subgroups, each subgroup does 1/4
@@ -1760,7 +1767,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
         if (small_rows) {
             return {scalar_flash_attention_num_small_rows, 64};
         } else {
-            return {scalar_flash_attention_num_large_rows, 32};
+            return {get_fa_scalar_num_large_rows(hsv), 32};
         }
     }
 
@@ -1779,7 +1786,11 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
 
     // small cols to reduce register count
     if (ggml_is_quantized(type) || hsk >= 256) {
-        return {64, 32};
+        if (hsk >= 512) {
+            return {32, 32};
+        } else {
+            return {64, 32};
+        }
     }
     return {64, 64};
 }
@@ -1821,7 +1832,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     const uint32_t warps = warptile[0] / warptile[10];
 
     const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
-    const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
+    const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
     const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
 
     const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
@@ -1946,10 +1957,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
         s_mmq_wg_denoms_k = { 32,  32, 1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        l_warptile_mmqid = { 256, 128, 128, 16, 0 };
         m_warptile_mmqid = { 256, 128, 64, 16, 0 };
         s_warptile_mmqid = { 256, 128, 64, 16, 0 };
-        l_mmqid_wg_denoms = { 128, 64, 1 };
+        l_mmqid_wg_denoms = { 128, 128, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
         s_mmqid_wg_denoms = { 128, 64, 1 };
 
@@ -6048,7 +6059,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
     // Needs to be kept up to date on shader changes
     GGML_UNUSED(hsv);
     const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = scalar_flash_attention_num_large_rows;
+    const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
     const uint32_t Bc = scalar_flash_attention_Bc;
 
     const uint32_t tmpsh = wg_size * sizeof(float);
@@ -6173,7 +6184,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     case FA_SCALAR:
     case FA_COOPMAT1:
         // We may switch from coopmat1 to scalar, so use the scalar limit for both
-        max_gqa = scalar_flash_attention_num_large_rows;
+        max_gqa = get_fa_scalar_num_large_rows(HSV);
         break;
     case FA_COOPMAT2:
         max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
index 888ce79f6ec113ac5bf47fe6fce154c086ab21ad..f481549911b92bb02ce4727ffce565c71345bf9e 100644 (file)
@@ -18,6 +18,7 @@
 #extension GL_KHR_cooperative_matrix : enable
 #extension GL_KHR_memory_scope_semantics : enable
 #extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
 #endif
 
 #ifdef MUL_MAT_ID
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
 
 #ifdef MUL_MAT_ID
 shared u16vec2 row_ids[4096];
+uint _ne1;
+#ifdef COOPMAT
+shared uint _ne1_sh;
+#endif
 #endif // MUL_MAT_ID
 
 #define NUM_WARPS (BLOCK_SIZE / WARP)
@@ -172,7 +177,47 @@ void main() {
     const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
 
 #ifdef MUL_MAT_ID
-    uint _ne1 = 0;
+#ifdef COOPMAT
+    // Spread the search across all elements in the first subgroup
+    if (gl_SubgroupID == 0) {
+        _ne1 = 0;
+        uint num_elements = p.nei1 * p.nei0;
+
+        uint ids[16];
+        uint iter = 0;
+
+        for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
+            // prefetch up to 16 elements
+            if (iter == 0) {
+                [[unroll]] for (uint k = 0; k < 16; ++k) {
+                    uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
+                    bool in_range = i < num_elements;
+                    uint ii1 = i / p.nei0;
+                    uint ii0 = i % p.nei0;
+                    ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+                }
+            }
+            uint i = j + gl_SubgroupInvocationID;
+            bool in_range = i < num_elements;
+            uint ii1 = i / p.nei0;
+            uint ii0 = i % p.nei0;
+            uint id = ids[iter++];
+            uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+            uint idx = subgroupBallotExclusiveBitCount(ballot);
+            if (in_range && id == expert_idx) {
+                row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
+            }
+            _ne1 += subgroupBallotBitCount(ballot);
+            iter &= 15;
+        }
+        _ne1_sh = _ne1;
+    }
+
+    barrier();
+
+    _ne1 = _ne1_sh;
+#else
+    _ne1 = 0;
     for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
         for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
             if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
@@ -183,6 +228,7 @@ void main() {
     }
 
     barrier();
+#endif
 
     // Workgroup has no work
     if (ic * BN >= _ne1) return;
index 9184657573281455975fc70eaae438e601327822..29e4b5c9ce2d4141b66819a86ad9c96197b490b3 100644 (file)
@@ -162,17 +162,32 @@ void main() {
         _ne1 = 0;
         uint num_elements = p.nei1 * p.nei0;
 
-        for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
+        uint ids[16];
+        uint iter = 0;
+
+        for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
+            // prefetch up to 16 elements
+            if (iter == 0) {
+                [[unroll]] for (uint k = 0; k < 16; ++k) {
+                    uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
+                    bool in_range = i < num_elements;
+                    uint ii1 = i / p.nei0;
+                    uint ii0 = i % p.nei0;
+                    ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+                }
+            }
+            uint i = j + gl_SubgroupInvocationID;
             bool in_range = i < num_elements;
-            uint ii0 = i % p.nei0;
             uint ii1 = i / p.nei0;
-            uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+            uint ii0 = i % p.nei0;
+            uint id = ids[iter++];
             uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
             uint idx = subgroupBallotExclusiveBitCount(ballot);
             if (in_range && id == expert_idx) {
                 row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
             }
             _ne1 += subgroupBallotBitCount(ballot);
+            iter &= 15;
         }
         _ne1_sh = _ne1;
     }
@@ -414,17 +429,31 @@ void main() {
                 fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
             }
 
-            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+            if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
-            coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
 #ifdef MUL_MAT_ID
-            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
 #else
-            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
 #endif
 
-            sum = coopMatMulAdd(mat_a, mat_b, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            } else {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+#ifdef MUL_MAT_ID
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+#else
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+#endif
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
         }
 
         // Convert from ACC_TYPE to D_TYPE