]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: support solve_tri with larger N/K values (llama/17781)
authorJeff Bolz <redacted>
Sat, 6 Dec 2025 07:56:45 +0000 (01:56 -0600)
committerGeorgi Gerganov <redacted>
Fri, 12 Dec 2025 15:53:20 +0000 (17:53 +0200)
Split N into chunks to fit into shared memory.
If K > 128, use a larger workgroup with enough invocations.
Add perf tests matching qwen3next.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp

index b40b356ce215101d0d67aab549df6a85eb583b35..6a52f1342e600b94773ae8d7b05409fde36e23a6 100644 (file)
@@ -4033,10 +4033,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     for (auto &s : device->pipeline_solve_tri_f32) {
         const vk_solve_tri_pipeline_state &state = s.first;
+
+        // Max number of rows to load at a time, limited by shared memory
+        const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));
+        // Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory
+        const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));
+
         ggml_vk_create_pipeline(
             device, s.second, "solve_tri_f32",
             solve_tri_f32_len, solve_tri_f32_data, "main", 3,
-            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
+            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);
     }
 
 #define IM2COL(bda) \
@@ -14025,10 +14031,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 const uint32_t N = op->src[0]->ne[0];
                 const uint32_t K = op->src[1]->ne[0];
                 // K dimension limited to workgroup size
-                if (K > 128) {
+                if (K > 1u << device->max_workgroup_size_log2) {
                     return false;
                 }
-                if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
+                const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));
+
+                if (batch_N == 0) {
                     return false;
                 }
                 return true;
index 253a9e7efee523a1fea7921634acc1dcfeda69af..3b65145032ccfd5f5f3c5d3e261c33e09a22140a 100644 (file)
@@ -5,8 +5,9 @@
 
 layout (constant_id = 1) const uint N = 64;
 layout (constant_id = 2) const uint K = 32;
+layout (constant_id = 3) const uint BATCH_N = 32;
 
-layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 4, local_size_y = 1, local_size_z = 1) in;
 
 uint a_base, b_base, x_base;
 
@@ -22,8 +23,8 @@ void store_x(uint r, uint c, FLOAT_TYPE v) {
     data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
 }
 
-shared FLOAT_TYPE shA[N * N];
-shared FLOAT_TYPE shB[N * K];
+shared FLOAT_TYPE shA[BATCH_N * N];
+shared FLOAT_TYPE shB[BATCH_N * K];
 
 void main() {
     const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
@@ -39,34 +40,42 @@ void main() {
     b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
     x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;
 
-    // Load the A matrix into shA
-    [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
-        uint idx = i + tid;
-        if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
-            shA[idx] = get_a(idx / N, idx % N);
+    FLOAT_TYPE X[N];
+
+    // Loop over batches of rows
+    [[unroll]] for (uint row_base = 0; row_base < N; row_base += BATCH_N) {
+        const uint cur_N = min(BATCH_N, N - row_base);
+
+        // Load the A matrix batch into shA
+        [[unroll]] for (uint i = 0; i < cur_N * N; i += gl_WorkGroupSize.x) {
+            uint idx = i + tid;
+            if (((cur_N * N) % gl_WorkGroupSize.x == 0) || idx < cur_N * N) {
+                shA[idx] = get_a(row_base + idx / N, idx % N);
+            }
         }
-    }
-    // Load the B matrix into shB
-    [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
-        uint idx = i + tid;
-        if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
-            shB[idx] = get_b(idx / K, idx % K);
+        // Load the B matrix batch into shB
+        [[unroll]] for (uint i = 0; i < cur_N * K; i += gl_WorkGroupSize.x) {
+            uint idx = i + tid;
+            if (((cur_N * K) % gl_WorkGroupSize.x == 0) || idx < cur_N * K) {
+                shB[idx] = get_b(row_base + idx / K, idx % K);
+            }
         }
-    }
-    barrier();
+        barrier();
 
-    FLOAT_TYPE X[N];
-    // Each thread solves one column
-    if (tid < K) {
-        [[unroll]] for (int r = 0; r < N; ++r) {
-            FLOAT_TYPE b = shB[r * K + tid];
-            // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
-            [[unroll]] for (int c = 0; c < r; ++c) {
-                b -= shA[r * N + c] * X[c];
+        // Each thread solves one column
+        if (tid < K) {
+            [[unroll]] for (uint row_offset = 0; row_offset < cur_N; ++row_offset) {
+                uint r = row_base + row_offset;
+                FLOAT_TYPE b = shB[row_offset * K + tid];
+                // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
+                [[unroll]] for (int c = 0; c < r; ++c) {
+                    b -= shA[row_offset * N + c] * X[c];
+                }
+                FLOAT_TYPE x = b / shA[row_offset * N + r];
+                X[r] = x;
+                store_x(r, tid, x);
             }
-            FLOAT_TYPE x = b / shA[r * N + r];
-            X[r] = x;
-            store_x(r, tid, x);
         }
+        barrier();
     }
 }