]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: support softmax/FA batch and broadcast (llama/14449)
authorJeff Bolz <redacted>
Tue, 1 Jul 2025 08:32:56 +0000 (03:32 -0500)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/flash_attn.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
src/ggml-vulkan/vulkan-shaders/soft_max.comp

index b8e25ba20980a5cbaadb387b1e74e10febbe5e39..25f70127a62c562d74eaa4687fc74e69456c3d4c 100644 (file)
@@ -633,6 +633,7 @@ struct vk_flash_attn_push_constants {
     uint32_t nev2;
     uint32_t nev3;
     uint32_t nem1;
+    uint32_t nem2;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -643,7 +644,6 @@ struct vk_flash_attn_push_constants {
     uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
-    uint32_t nb31;
 
     float scale;
     float max_bias;
@@ -658,6 +658,7 @@ struct vk_flash_attn_push_constants {
     uint32_t split_kv;
     uint32_t k_num;
 };
+static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
 
 struct vk_op_push_constants {
     uint32_t KX;
@@ -756,6 +757,14 @@ struct vk_op_rope_push_constants {
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t ne02;
+    uint32_t ne12;
+    uint32_t ne13;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
     float scale;
     float max_bias;
     float m0;
@@ -6040,7 +6049,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
-    const uint32_t nbm1 = mask ? mask->nb[1] : 0;
+    const uint32_t nem2 = mask ? mask->ne[2] : 0;
 
     const uint32_t D = neq0;
     uint32_t N = neq1;
@@ -6203,7 +6212,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Try to use split_k when KV is large enough to be worth the overhead
     if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
-        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
+        split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
@@ -6213,9 +6222,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
-    // and the per-row m and L values (ne1 rows).
-    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
+    // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
+    // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
+    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
     if (split_k_size > ctx->device->max_memory_allocation_size) {
         GGML_ABORT("Requested preallocation size is too large");
     }
@@ -6307,11 +6316,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                               (uint32_t)neq2, (uint32_t)neq3,
                                               (uint32_t)nek2, (uint32_t)nek3,
                                               (uint32_t)nev2, (uint32_t)nev3,
-                                              nem1,
+                                              nem1, nem2,
                                               q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
                                               k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
                                               v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
-                                              nbm1,
                                               scale, max_bias, logit_softcap,
                                               mask != nullptr, n_head_log2, m0, m1,
                                               gqa_ratio, split_kv, split_k };
@@ -6334,13 +6342,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                     pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
+        const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
                                     {
                                         vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
                                     },
-                                    pc2, { (uint32_t)ne1, 1, 1 });
+                                    pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
                                     {
@@ -7666,7 +7674,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
     const uint32_t nrows_y = (uint32_t)src0->ne[1];
 
-    const uint32_t n_head_kv   = nrows_x/nrows_y;
+    const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
+    const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
+    const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
+    const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
+    const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
+
+    const uint32_t n_head_kv   = src0->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
@@ -7675,6 +7689,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
+        ne12, ne13,
+        nb11, nb12, nb13,
         scale, max_bias,
         m0, m1,
         n_head_log2,
@@ -10248,11 +10265,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                     return false;
                 }
-                // TODO: support broadcast
-                // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-                if (op->src[0]->ne[3] != 1) {
-                    return false;
-                }
                 // It's straightforward to support different K/V dequant, but would
                 // significantly increase the number of pipelines
                 if (op->src[1]->type != op->src[2]->type) {
@@ -10413,13 +10425,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ARGSORT:
         case GGML_OP_SUM:
index ce230a8f7d91038ed84f2ba4841052f110e5fbf7..6f80101d1c4326f5f4632e738d2e3da9d3677c4c 100644 (file)
@@ -99,6 +99,10 @@ void main() {
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
+    uint32_t m_offset = 0;
+    if (p.nem2 != 1) {
+        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+    }
 
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
@@ -150,7 +154,7 @@ void main() {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;
                 if (idx + tid < Bc * Br) {
-                    masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
+                    masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                 }
             }
             barrier();
@@ -277,7 +281,7 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        uint32_t o_offset = D * p.ne1 * split_k_index;
+        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
 
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
@@ -289,7 +293,7 @@ void main() {
             }
         }
 
-        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
                 perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -311,7 +315,7 @@ void main() {
         }
     }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
 
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
index 61d90e2d8ed219448c2d9cd710fe6165b399b4b0..67b935cf57e50d5e50288e0e16339e92002afb4b 100644 (file)
@@ -24,6 +24,7 @@ layout (push_constant) uniform parameter {
     uint32_t nev2;
     uint32_t nev3;
     uint32_t nem1;
+    uint32_t nem2;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -34,7 +35,6 @@ layout (push_constant) uniform parameter {
     uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
-    uint32_t nb31;
 
     float scale;
     float max_bias;
index da478be24fb6e7298ddff54438bf00313075a9c2..26fe50c7a81e58080d324c9d5e92512def9253b0 100644 (file)
@@ -123,6 +123,10 @@ void main() {
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
+    uint32_t m_offset = 0;
+    if (p.nem2 != 1) {
+        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+    }
 
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
@@ -181,7 +185,7 @@ void main() {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;
                 if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
+                    sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
                 }
             }
             barrier();
@@ -300,7 +304,7 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        uint32_t o_offset = D * p.ne1 * split_k_index;
+        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
 
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (tile_row(r) < N) {
@@ -312,7 +316,7 @@ void main() {
             }
         }
 
-        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (tile_row(r) < N) {
                 perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -334,7 +338,7 @@ void main() {
         }
     }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
 
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
index 6acf67a03a46351fbfd1d28540f951f02ef5648e..cf47bd53e28bb68fe0a4e12afa1ae35d24e378fe 100644 (file)
@@ -130,6 +130,11 @@ void main() {
         coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
     }
 
+    uint32_t m_offset = 0;
+    if (p.nem2 != 1) {
+        m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+    }
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
@@ -155,7 +160,7 @@ void main() {
 
             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
 
-            coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+            coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
             S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
         }
@@ -229,10 +234,10 @@ void main() {
     if (p.k_num > 1) {
         coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
 
-        uint32_t o_offset = D * p.ne1 * split_k_index;
+        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
 
-        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
         coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
         return;
@@ -250,7 +255,7 @@ void main() {
 
     O = Ldiag*O;
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
 
     coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
     if (p.gqa_ratio > 1) {
index a7e3956854c442dc56b53779932c55c6d4eba610..599cef072e931e3a503a3512901f484357ebad1b 100644 (file)
@@ -12,6 +12,7 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
 layout (push_constant) uniform parameter {
     uint D;
     uint N;
+    uint ne3;
     uint k_num;
 } p;
 
@@ -19,13 +20,14 @@ void main() {
     // Each workgroup handles a row
     const uint n = gl_WorkGroupID.x;
     const uint tid = gl_LocalInvocationID.x;
+    const uint iq3 = gl_WorkGroupID.z;
 
     uint D = p.D;
     uint N = p.N;
     uint k_num = p.k_num;
 
-    uint l_offset = D * N * k_num + n;
-    uint m_offset = D * N * k_num + N + n;
+    uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
+    uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
     uint lm_stride = N * 2;
 
     // Compute the max m value for the row
@@ -49,11 +51,11 @@ void main() {
     for (uint d = tid; d < D; d += BLOCK_SIZE) {
         float O = 0.0;
         [[unroll]] for (uint k = 0; k < k_num; ++k) {
-            uint o_offset = D * N * k + D * n + d;
+            uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
             float m = data_a[m_offset + k * lm_stride];
             O += exp(m - m_max) * data_a[o_offset];
         }
         O *= L;
-        data_d[D * n + d] = O;
+        data_d[iq3 * D * N + D * n + d] = O;
     }
 }
index 51fc2dc7ed406cf7a87401ed3913f8ae2241a95f..5bcd3b1e3ddc67989f95b8971041ddef2647cced 100644 (file)
@@ -6,6 +6,14 @@ layout (push_constant) uniform parameter
 {
     uint KX;
     uint KY;
+    uint ne00;
+    uint ne01;
+    uint ne02;
+    uint ne12;
+    uint ne13;
+    uint nb11;
+    uint nb12;
+    uint nb13;
     float scale;
     float max_bias;
     float m0;
@@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
 void soft_max(uint num_iters) {
     const uint tid = gl_LocalInvocationID.x;
     const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
 
     if (rowx >= p.nrows_x) {
         return;
@@ -41,7 +57,7 @@ void soft_max(uint num_iters) {
 
     // ALiBi
     if (p.max_bias > 0.0f) {
-        const uint h = rowx/p.KY; // head index
+        const uint h = (rowx / p.ne01) % p.ne02; // head index
 
         const float base = h < p.n_head_log2 ? p.m0 : p.m1;
         const uint   exp  = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
@@ -67,7 +83,7 @@ void soft_max(uint num_iters) {
 
         FLOAT_TYPE b = FLOAT_TYPE(0);
         if (p.KY > 0 && col < p.KX) {
-            b = data_b[rowy * p.KX + col];
+            b = data_b[rowy_start + col];
         }
 
         FLOAT_TYPE v = a * p.scale + slope * b;
@@ -111,7 +127,7 @@ void soft_max(uint num_iters) {
         if (idx < DATA_CACHE_SIZE) {
             val = exp(data_cache[idx] - max_val);
         } else {
-            val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+            val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
         }
         sum += val;
         if (idx < DATA_CACHE_SIZE) {