]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: make FA mask/softcap enables spec constants (#19309)
authorJeff Bolz <redacted>
Fri, 6 Feb 2026 07:49:58 +0000 (01:49 -0600)
committerGitHub <redacted>
Fri, 6 Feb 2026 07:49:58 +0000 (08:49 +0100)
* vulkan: make FA mask/softcap enables spec constants

* don't specialize for sinks

* bump timeout a little bit

.github/workflows/build.yml
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index 8ce679bd9ab2341cb4e5223a7fd16708d5e292fa..51a3dc76e9e9d4250e681e53747ca2a19671d9a7 100644 (file)
@@ -468,7 +468,7 @@ jobs:
           export GGML_VK_VISIBLE_DEVICES=0
           export GGML_VK_DISABLE_F16=1
           # This is using llvmpipe and runs slower than other backends
-          ctest -L main --verbose --timeout 4200
+          ctest -L main --verbose --timeout 4800
 
   ubuntu-24-cmake-webgpu:
     runs-on: ubuntu-24.04
index 4357da24d426db53ea3d39b4800e335b4ac12e8b..72097ffd0ffc73ae6e09ff7783ed20904827e0ae 100644 (file)
@@ -402,19 +402,19 @@ enum FaCodePath {
 };
 
 struct vk_fa_pipeline_state {
-    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt)
-        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {}
+    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags)
+        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {}
 
     uint32_t HSK, HSV;
     bool small_rows, small_cache;
     FaCodePath path;
     bool aligned;
     bool f32acc;
-    bool use_mask_opt;
+    uint32_t flags;
 
     bool operator<(const vk_fa_pipeline_state &b) const {
-        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) <
-               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt);
+        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) <
+               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags);
     }
 };
 
@@ -3193,7 +3193,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
     };
 
-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
@@ -3225,7 +3225,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // AMD prefers loading K directly from global memory
         const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
 
-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt};
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
     };
 
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3237,19 +3237,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
             FaCodePath path = fa.first.path; \
             bool aligned = fa.first.aligned; \
             bool f32acc = fa.first.f32acc; \
-            bool use_mask_opt = fa.first.use_mask_opt; \
+            uint32_t flags = fa.first.flags; \
             if (path == FAPATH) { \
                 if (aligned) { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } \
                 } else { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } \
                 } \
             } \
@@ -8595,10 +8595,26 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
 
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
     // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
     bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
 
-    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt);
+    uint32_t flags = (use_mask_opt       ? 1 : 0) |
+                     (mask != nullptr    ? 2 : 0) |
+                     (logit_softcap != 0 ? 4 : 0);
+
+    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags);
 
     vk_pipeline pipeline = nullptr;
 
@@ -8678,18 +8694,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    float scale         = 1.0f;
-    float max_bias      = 0.0f;
-    float logit_softcap = 0.0f;
-
-    memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
-
-    if (logit_softcap != 0) {
-        scale /= logit_softcap;
-    }
-
     const uint32_t n_head_kv   = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
@@ -8703,7 +8707,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
     vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
 
-    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
+    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
 
     if (use_mask_opt)
     {
index 49a3c530cb6f480e36b1a277c0e5a42ed408cc6d..914f131c96518adc01aa5b096bbd6ef23281c52b 100644 (file)
@@ -127,7 +127,7 @@ void main() {
             continue;
         }
         // Only load if the block is not all zeros
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
+        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
@@ -181,7 +181,7 @@ void main() {
             }
         }
 
-        if (p.logit_softcap != 0.0f) {
+        if (LOGIT_SOFTCAP) {
             [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                 [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                     Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
@@ -189,7 +189,7 @@ void main() {
             }
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
+        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                     float mvf = masksh[c * cols_per_iter + col_tid][r];
index 252451101ab484add027fe98129475f5088a6290..74005cffb3f689ed4b4e00b5bec42ca1e2d30ade 100644 (file)
@@ -10,7 +10,11 @@ layout (constant_id = 5) const uint32_t Clamp = 0;
 layout (constant_id = 6) const uint32_t D_split = 16;
 layout (constant_id = 7) const uint32_t SubGroupSize = 32;
 layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
-layout (constant_id = 9) const bool     USE_MASK_OPT = false;
+layout (constant_id = 9) const uint32_t Flags = 0;
+
+const bool USE_MASK_OPT  = (Flags & 1) != 0;
+const bool MASK_ENABLE   = (Flags & 2) != 0;
+const bool LOGIT_SOFTCAP = (Flags & 4) != 0;
 
 // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
 const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -60,7 +64,6 @@ layout (push_constant) uniform parameter {
 } p;
 
 #define SINK_ENABLE_BIT (1<<24)
-#define MASK_ENABLE_BIT (1<<16)
 #define N_LOG2_MASK 0xFFFF
 
 layout (binding = 4) readonly buffer S {float data_s[];};
index 89af3697e1d42303cb0f330390f4096520afc909..b31777382341a5c5c79cb04f67faf60be899f9c5 100644 (file)
@@ -160,7 +160,7 @@ void main() {
             mask_cache[idx] = f16vec4(0);
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (MASK_ENABLE) {
 
             if (USE_MASK_OPT && mask_opt_idx != j / 16) {
                 mask_opt_idx = j / 16;
@@ -303,7 +303,7 @@ void main() {
         coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
         barrier();
 
-        if (p.logit_softcap != 0.0f) {
+        if (LOGIT_SOFTCAP) {
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) / (Br / 4);
                 uint32_t r = (idx + tid) % (Br / 4);
@@ -314,7 +314,7 @@ void main() {
             barrier();
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (MASK_ENABLE) {
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) / (Br / 4);
                 uint32_t r = (idx + tid) % (Br / 4);
index 47b110621b713060d130064d528761990da37305..b07c21f6e55e4623b9ed2b3573966b4684d2ca44 100644 (file)
@@ -155,7 +155,7 @@ void main() {
     for (uint32_t j = start_j; j < end_j; ++j) {
 
         coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (MASK_ENABLE) {
 
             if (USE_MASK_OPT && mask_opt_idx != j / 16) {
                 mask_opt_idx = j / 16;
@@ -197,14 +197,14 @@ void main() {
         coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
         S = coopMatMulAdd(Qf16, K_T, S);
 
-        if (p.logit_softcap != 0.0f) {
+        if (LOGIT_SOFTCAP) {
             [[unroll]]
             for (int k = 0; k < S.length(); ++k) {
                 S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
             }
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (MASK_ENABLE) {
             S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
         }