]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: Pad N dimension of B matrix for coopmat2 perf, to avoid bounds checking ...
authorJeff Bolz <redacted>
Mon, 17 Mar 2025 09:41:59 +0000 (04:41 -0500)
committerGeorgi Gerganov <redacted>
Thu, 27 Mar 2025 09:06:03 +0000 (11:06 +0200)
* vulkan: Pad N dimension of B matrix for coopmat2 perf, to avoid bounds checking

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

index e46007a52f56e74c4963b2421b42dab671c028d9..a837b0dda4cbf60b4941d745f9da6772d1309115 100644 (file)
@@ -29,6 +29,7 @@
 
 #include "ggml-vulkan-shaders.hpp"
 
+#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define VK_VENDOR_ID_AMD 0x1002
@@ -368,6 +369,7 @@ struct vk_mat_mat_push_constants {
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t k_split;
     uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -380,6 +382,7 @@ struct vk_mat_mat_id_push_constants {
     uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_id_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -3882,18 +3885,19 @@ static void ggml_vk_matmul(
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
+        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
+        uint32_t padded_n) {
         VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
-        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
+        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
         return;
     }
 
     GGML_ASSERT(batch_stride_d == m * n);
 
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
+    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
     // Make sure enough workgroups get assigned for split k to work
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
     ggml_vk_sync_buffers(subctx);
@@ -3937,14 +3941,15 @@ static void ggml_vk_matmul_id(
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
+        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
+        uint32_t padded_n) {
     VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
         "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
         "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
         "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
     ggml_vk_sync_buffers(subctx);
     const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
-                                              nei0, nei1, nbi1, ne11 };
+                                              nei0, nei1, nbi1, ne11, padded_n };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
 }
 
@@ -4106,15 +4111,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
 
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
 
+    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = padded_n * ne10;
+    const int d_ne = ne11 * ne01;
+
     const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
 
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
@@ -4237,7 +4244,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
         ne01, ne11, ne10,
         ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
-        split_k, ne12*ne13, ne02, ne12, r2, r3
+        split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
     );  // NOLINT
 }
 
@@ -4688,15 +4695,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
 
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne21 * ne20;
-
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
 
+    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
+    const uint64_t x_ne = ne01 * ne00;
+    const uint64_t y_ne = padded_n * ne10;
+    const uint64_t d_ne = ne21 * ne20;
+
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -4815,7 +4824,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
         ne01, ne21, ne10, ne10, ne10, ne01,
         stride_batch_x, stride_batch_y, ne20*ne21,
-        n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
+        n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
     );  // NOLINT
 }
 
@@ -6775,7 +6784,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
             ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
             m, n, k,
             k, k, m, k*m, k*n, m*n,
-            split_k, batch, batch, batch, 1, 1
+            split_k, batch, batch, batch, 1, 1, n
         );
     }
     ggml_vk_ctx_end(subctx);
@@ -7120,7 +7129,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
             ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
             m, n, k,
             k, k, m, k*m, k*n, m*n,
-            split_k, batch, batch, batch, 1, 1
+            split_k, batch, batch, batch, 1, 1, n
         );
     }
     ggml_vk_ctx_end(subctx);
index 66dd2c860d82dce85a83b57abebee0cbb2c0fbb3..5b7a4efe2ca8e1655aaf1932cf344bcb9ac3e905 100644 (file)
@@ -48,6 +48,8 @@ layout (push_constant) uniform parameter
     uint broadcast2;
     uint broadcast3;
 #endif
+    // N dimension for the B matrix can be >= p.N
+    uint padded_N;
 } p;
 
 
@@ -202,18 +204,19 @@ void main() {
 #endif
 
     // Use end_k rather than p.K as the dimension because that's what
-    // we need to bound check against when using split_k
+    // we need to bound check against when using split_k.
+    // Bounds check B against padded_N, but bounds check D against N.
     tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
-    tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k);
+    tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
     tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
     tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
-    tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k);
+    tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
 
     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
 
 #if !defined(MUL_MAT_ID)
     // Detect a fast path where all loads are entirely in bounds and no clamping is required
-    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
+    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
 #if QUANT_K == 1
         (stride_a % 8) == 0 &&
 #endif
@@ -263,7 +266,7 @@ void main() {
 #ifdef MUL_MAT_ID
             bool unclampedB = true;
 #else
-            bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0;
+            bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
 #endif
             if (unclampedA && unclampedB) {
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);