]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: support arbitrary KV dimension in flash attention (#16160)
authorJeff Bolz <redacted>
Sat, 27 Sep 2025 20:43:39 +0000 (16:43 -0400)
committerGitHub <redacted>
Sat, 27 Sep 2025 20:43:39 +0000 (22:43 +0200)
The "Clamp" spec constant is already based on whether KV is a multiple of Bc,
so use that to control whether bounds checking is performed. Add bounds checking
to the scalar and coopmat1 paths. Coopmat2 didn't need any changes (the K/V
tensors are already optionally clamped, nothing else needed to be changed).

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

index 482445c6fea2c4c86a2edc9cbbdcafc998018763..43b906e5ed96ddfadf3823c3c6454eedf24ef854 100644 (file)
@@ -117,6 +117,9 @@ void main() {
 
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                continue;
+            }
             [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
 #if BLOCK_SIZE > 1
                 uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
@@ -155,7 +158,11 @@ void main() {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;
                 if (idx + tid < Bc * Br) {
-                    masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+                        masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+                    } else {
+                        masksh[c][r] = float(0);
+                    }
                 }
             }
             barrier();
@@ -172,8 +179,11 @@ void main() {
 
         float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            rowmaxf[r] = Sf[r][0];
+            rowmaxf[r] = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                    continue;
+                }
                 rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
             }
             Moldf[r] = Mf[r];
@@ -190,6 +200,9 @@ void main() {
             // Compute sum across row of P
             rowsumf[r] = 0.0;
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                    continue;
+                }
                 rowsumf[r] += Pf[r][c];
             }
 
@@ -203,6 +216,9 @@ void main() {
         }
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                continue;
+            }
             [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 #if BLOCK_SIZE > 1
                 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
index f73e17e1fa8d91a93e4b8bf4e87175c562a1d855..e80eff27815ae36323e769f215872076eb8b76b7 100644 (file)
@@ -13,6 +13,8 @@ layout (constant_id = 6) const uint32_t D_split = 16;
 const uint32_t HSK_pad = (HSK + 15) & ~15;
 const uint32_t HSV_pad = (HSV + 15) & ~15;
 
+const bool KV_bounds_check = Clamp != 0;
+
 layout (push_constant) uniform parameter {
     uint32_t N;
     uint32_t KV;
index 63b32171b0c07076967faf4a272e048d52506b84..ddb1246e0ba7c77ab8047a15ee07ad769eb8d28d 100644 (file)
@@ -152,14 +152,17 @@ void main() {
             uint32_t d = (idx + tid) % (HSK / 4);
             uint32_t c = (idx + tid) / (HSK / 4);
             if (c < Bc && d < HSK / 4) {
+                f16vec4 K_Tf = f16vec4(0);
+                if (!KV_bounds_check || j * Bc + c < KV) {
 #if BLOCK_SIZE > 1
-                uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
-                uint ib = coord / BLOCK_SIZE;
-                uint iqs = (coord % BLOCK_SIZE);
-                f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
+                    uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
+                    uint ib = coord / BLOCK_SIZE;
+                    uint iqs = (coord % BLOCK_SIZE);
+                    K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
 #else
-                f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
+                    K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
 #endif
+                }
 
                 ksh[c * kshstride + d] = K_Tf;
             }
@@ -202,7 +205,9 @@ void main() {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;
                 if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+                        sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
+                    }
                 }
             }
             barrier();
@@ -210,8 +215,11 @@ void main() {
 
         float eMf[rows_per_thread];
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
+            float rowmaxf = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                    continue;
+                }
                 rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
             }
             float Moldf = Mf[r];
@@ -233,6 +241,9 @@ void main() {
         }
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                continue;
+            }
             float Pf[rows_per_thread];
             [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                 Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);