]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: Preprocess FA mask to detect all-neg-inf and all-zero. (llama/19281)
authorJeff Bolz <redacted>
Thu, 5 Feb 2026 15:26:38 +0000 (09:26 -0600)
committerGeorgi Gerganov <redacted>
Sat, 7 Feb 2026 08:37:38 +0000 (10:37 +0200)
Write out a 2-bit code per block and avoid loading the mask when it
matches these two common cases.

Apply this optimization when the mask is relatively large (i.e. prompt
processing).

src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/flash_attn.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

index ff9cb7355c296dfaba6a4bbd38f90a88d8f1a070..4357da24d426db53ea3d39b4800e335b4ac12e8b 100644 (file)
@@ -402,18 +402,19 @@ enum FaCodePath {
 };
 
 struct vk_fa_pipeline_state {
-    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
-        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
+    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt)
+        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {}
 
     uint32_t HSK, HSV;
     bool small_rows, small_cache;
     FaCodePath path;
     bool aligned;
     bool f32acc;
+    bool use_mask_opt;
 
     bool operator<(const vk_fa_pipeline_state &b) const {
-        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
-               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
+        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) <
+               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt);
     }
 };
 
@@ -820,6 +821,8 @@ struct vk_device_struct {
 
     std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
 
+    std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
+
     vk_pipeline pipeline_flash_attn_split_k_reduce;
     vk_pipeline pipeline_count_experts;
 
@@ -1549,6 +1552,18 @@ struct vk_op_flash_attn_split_k_reduce_push_constants {
     uint32_t sinks;
 };
 
+struct vk_op_flash_attn_mask_opt_push_constants {
+    uint32_t nem0;
+    uint32_t nem1;
+    uint32_t nem2;
+    uint32_t nbm1;
+    uint32_t nbm2;
+    uint32_t nbm3;
+    uint32_t nbd1;
+    uint32_t nbd2;
+    uint32_t nbd3;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1757,6 +1772,7 @@ class vk_perf_logger {
                 " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
                 " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
                 " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
+            *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
             return name.str();
         }
         if (node->op == GGML_OP_TOP_K) {
@@ -3177,7 +3193,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
     };
 
-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
@@ -3209,7 +3225,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // AMD prefers loading K directly from global memory
         const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
 
-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt};
     };
 
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3221,18 +3237,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
             FaCodePath path = fa.first.path; \
             bool aligned = fa.first.aligned; \
             bool f32acc = fa.first.f32acc; \
+            bool use_mask_opt = fa.first.use_mask_opt; \
             if (path == FAPATH) { \
                 if (aligned) { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } \
                 } else { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } \
                 } \
             } \
@@ -4028,6 +4045,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
 
+    for (auto &it : device->pipeline_fa_mask_opt) {
+        auto BrBc = it.first;
+        ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size);
+    }
+
     if (device->subgroup_clustered && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
     } else {
@@ -8400,8 +8422,6 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     const uint32_t acctype = f32acc ? 4 : 2;
     const uint32_t f16vec4 = 8;
 
-    const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
-
     const uint32_t qstride = hsk_pad / 4 + 2;
     const uint32_t Qf = Br * qstride * f16vec4;
 
@@ -8418,7 +8438,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
 
     const uint32_t slope = Br * acctype;
 
-    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
+    const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
@@ -8445,6 +8465,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
     GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 
+    const uint32_t nem0 = mask ? mask->ne[0] : 0;
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
     const uint32_t nem2 = mask ? mask->ne[2] : 0;
     const uint32_t nem3 = mask ? mask->ne[3] : 0;
@@ -8574,7 +8595,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
 
-    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
+    // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
+    bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
+
+    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt);
 
     vk_pipeline pipeline = nullptr;
 
@@ -8625,10 +8649,32 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         ggml_vk_preallocate_buffers(ctx, subctx);
     }
 
-    {
-        // Request descriptor sets
-        if (split_k > 1) {
-            ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+    auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache);
+    const uint32_t Br = rows_cols[0];
+    const uint32_t Bc = rows_cols[1];
+
+    const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
+    const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
+
+    vk_pipeline pipeline_fa_mask_opt = nullptr;
+    if (use_mask_opt) {
+        std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
+        auto &pipelines = ctx->device->pipeline_fa_mask_opt;
+        auto it = pipelines.find({Br, Bc});
+        if (it != pipelines.end()) {
+            pipeline_fa_mask_opt = it->second;
+        } else {
+            pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
+        }
+        assert(pipeline_fa_mask_opt);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
+
+        if (ctx->prealloc_size_y < mask_opt_size) {
+            ctx->prealloc_size_y = mask_opt_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
         }
     }
 
@@ -8655,9 +8701,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
     vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
     vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
+    vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
 
     uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
 
+    if (use_mask_opt)
+    {
+        const vk_op_flash_attn_mask_opt_push_constants opt_pc = {
+            nem0,
+            nem1,
+            nem2,
+            (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),
+            (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),
+            (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),
+            mask_opt_num_dwords,
+            mask_opt_num_dwords * CEIL_DIV(nem1, Br),
+            mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,
+        };
+
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,
+                                  { mask_buf, mask_opt_buf }, opt_pc,
+                                  { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+
     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
                                               (uint32_t)neq2, (uint32_t)neq3,
@@ -8672,13 +8739,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                               gqa_ratio, split_kv, split_k };
 
     if (split_k > 1) {
+        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+
         if (ctx->prealloc_split_k_need_sync) {
             ggml_vk_sync_buffers(ctx, subctx);
         }
         workgroups_x *= pipeline->wg_denoms[0];
         vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
+                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
                                     // We only use split_k when group query attention is enabled, which means
                                     // there's no more than one tile of rows (i.e. workgroups_x would have been
                                     // one). We reuse workgroups_x to mean the number of splits, so we need to
@@ -8697,7 +8766,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             workgroups_x *= pipeline->wg_denoms[0];
         }
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
+                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},
                                     pc, { workgroups_x, workgroups_y, workgroups_z });
     }
 }
index 3ce8d07be8013412a24d9081c5cb38fcb31b190f..49a3c530cb6f480e36b1a277c0e5a42ed408cc6d 100644 (file)
@@ -94,6 +94,10 @@ void main() {
         }
     }
 
+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
 #if BLOCK_SIZE > 1
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
@@ -104,15 +108,28 @@ void main() {
     uint32_t m_offset = gqa_iq1*KV;
     if (p.nem2 != 1 || p.nem3 != 1) {
         m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
     }
 
+    uint32_t mask_opt = 0;
+    uint32_t mask_opt_idx = ~0;
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+            mask_opt_idx = j / 16;
+            mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+        }
+        uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+        if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+            // skip this block
+            continue;
+        }
+        // Only load if the block is not all zeros
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
-            float max_mask = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;
@@ -120,25 +137,12 @@ void main() {
                     if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
                         float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                         masksh[c][r] = m;
-                        max_mask = max(max_mask, m);
                     } else {
                         masksh[c][r] = float(0);
                     }
                 }
             }
-            // skip the block if the mask is entirely -inf
-            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
-            barrier();
-            if (gl_SubgroupInvocationID == 0) {
-                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
-            }
             barrier();
-            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
-                max_mask = max(max_mask, tmpsh[s]);
-            }
-            if (max_mask <= NEG_FLT_MAX_OVER_2) {
-                continue;
-            }
         }
 
         float Sf[Br][cols_per_thread];
@@ -185,7 +189,7 @@ void main() {
             }
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                     float mvf = masksh[c * cols_per_iter + col_tid][r];
@@ -256,9 +260,6 @@ void main() {
         barrier();
     }
 
-    // prevent race on tmpsh
-    barrier();
-
     // reduce across threads
 
     [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
index 23a4d2c00586b3be4ef4e23fc4e39d73bc0edc18..252451101ab484add027fe98129475f5088a6290 100644 (file)
@@ -10,6 +10,7 @@ layout (constant_id = 5) const uint32_t Clamp = 0;
 layout (constant_id = 6) const uint32_t D_split = 16;
 layout (constant_id = 7) const uint32_t SubGroupSize = 32;
 layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
+layout (constant_id = 9) const bool     USE_MASK_OPT = false;
 
 // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
 const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -66,6 +67,11 @@ layout (binding = 4) readonly buffer S {float data_s[];};
 
 layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
 
+layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
+
+#define MASK_OPT_ALL_NEG_INF 1
+#define MASK_OPT_ALL_ZERO 2
+
 #define BINDING_IDX_K 0
 #define BINDING_IDX_V 1
 #if defined(DATA_A_F32)
index 83d52d19d672b330e762f8ccb5a02c82cb6d350e..89af3697e1d42303cb0f330390f4096520afc909 100644 (file)
@@ -42,8 +42,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
-shared float tmpsh[row_split];
-
 const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];
 
@@ -134,6 +132,10 @@ void main() {
         }
     }
 
+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
 #if BLOCK_SIZE > 1
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
@@ -144,66 +146,74 @@ void main() {
     uint32_t m_offset = gqa_iq1*KV;
     if (p.nem2 != 1 || p.nem3 != 1) {
         m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
     }
 
+    uint32_t mask_opt = 0;
+    uint32_t mask_opt_idx = ~0;
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
         f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
+        [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
+            mask_cache[idx] = f16vec4(0);
+        }
+
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
-            float max_mask = NEG_FLT_MAX_OVER_2;
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) / (Br / 4);
-                uint32_t r = (idx + tid) % (Br / 4);
-                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
-                    if ((!KV_bounds_check || j * Bc + c < KV)) {
-                        f16vec4 m;
-                        if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
-                            m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
-                                        data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
-                                        data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
-                                        data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
-                            max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
-                        } else if (i * Br + r * 4 + 2 < p.nem1) {
-                            m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
-                                        data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
-                                        data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
-                                        0.0);
-                            max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
-                        } else if (i * Br + r * 4 + 1 < p.nem1) {
-                            m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
-                                        data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
-                                        0.0,
-                                        0.0);
-                            max_mask = max(max(max_mask, float(m[0])), float(m[1]));
-                        } else if (i * Br + r * 4 < p.nem1) {
-                            m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
-                                        0.0,
-                                        0.0,
-                                        0.0);
-                            max_mask = max(max_mask, float(m[0]));
-                        } else {
-                            m = f16vec4(0.0);
+            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+                mask_opt_idx = j / 16;
+                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+            }
+            uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+                // skip this block
+                continue;
+            }
+            // Only load if the block is not all zeros
+            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+                float max_mask = NEG_FLT_MAX_OVER_2;
+                [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
+                    uint32_t c = (idx + tid) / (Br / 4);
+                    uint32_t r = (idx + tid) % (Br / 4);
+                    if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
+                        if ((!KV_bounds_check || j * Bc + c < KV)) {
+                            f16vec4 m;
+                            if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
+                                max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
+                            } else if (i * Br + r * 4 + 2 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
+                                            0.0);
+                                max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
+                            } else if (i * Br + r * 4 + 1 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                            0.0,
+                                            0.0);
+                                max_mask = max(max(max_mask, float(m[0])), float(m[1]));
+                            } else if (i * Br + r * 4 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            0.0,
+                                            0.0,
+                                            0.0);
+                                max_mask = max(max_mask, float(m[0]));
+                            } else {
+                                m = f16vec4(0.0);
+                            }
+                            mask_cache[idx / WorkGroupSize] = m;
                         }
-                        mask_cache[idx / WorkGroupSize] = m;
                     }
                 }
             }
-            // skip the block if the mask is entirely -inf
-            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
-            barrier();
-            if (gl_SubgroupInvocationID == 0) {
-                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
-            }
-            barrier();
-            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
-                max_mask = max(max_mask, tmpsh[s]);
-            }
-            if (max_mask <= NEG_FLT_MAX_OVER_2) {
-                continue;
-            }
         }
 
         if (K_LOAD_SHMEM != 0) {
index 54f1b0b62267e4a6e061320f41f974a99025c2f7..47b110621b713060d130064d528761990da37305 100644 (file)
@@ -138,48 +138,53 @@ void main() {
         coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
     }
 
+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
     uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
     if (p.nem2 != 1 || p.nem3 != 1) {
         m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
     }
 
+    uint32_t mask_opt = 0;
+    uint32_t mask_opt_idx = ~0;
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            if (nem1_bounds_check) {
-                tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
-                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
-                tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
-
-                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
-
-                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
-
-                // skip the block if the mask is entirely -inf
-                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
-                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
-                    continue;
-                }
-            } else {
-                tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
-                // Don't clamp against nem1 when GQA is enabled
-                uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
-                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
-
-                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
 
-                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
-
-                // skip the block if the mask is entirely -inf
-                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
-                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
-                    continue;
+            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+                mask_opt_idx = j / 16;
+                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+            }
+            uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+                // skip this block
+                continue;
+            }
+            // Only load if the block is not all zeros
+            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+                if (nem1_bounds_check) {
+                    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+                    tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
+
+                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+                } else {
+                    tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
+                    // Don't clamp against nem1 when GQA is enabled
+                    uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
+                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
+                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+
+                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
                 }
             }
         }
diff --git a/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp b/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp
new file mode 100644 (file)
index 0000000..8c92c1a
--- /dev/null
@@ -0,0 +1,142 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint NUM_SUBGROUPS = 4;
+layout (constant_id = 2) const uint Br = 32;
+layout (constant_id = 3) const uint Bc = 32;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {float16_t data_a[];};
+layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];};
+layout (binding = 1) writeonly buffer D {uint data_d[];};
+
+layout (push_constant) uniform parameter {
+    uint nem0;
+    uint nem1;
+    uint nem2;
+    uint nbm1;
+    uint nbm2;
+    uint nbm3;
+    uint nbd1;
+    uint nbd2;
+    uint nbd3;
+};
+
+#define MASK_OPT_ALL_NEG_INF 1
+#define MASK_OPT_ALL_ZERO 2
+
+shared float minsh[NUM_SUBGROUPS];
+shared float maxsh[NUM_SUBGROUPS];
+
+// For each Br x Bc block of the mask (input) buffer, read all values and check
+// if it's all -inf or all zero. Write out a two-bit code indicating which it is
+// (or zero for neither). Each workgroup processes 16 tiles and writes out a
+// 32-bit result mask.
+//
+// TODO: This is a lot of work per workgroup, might make sense to split this into
+// more workgroups in the future.
+void main() {
+    // Each workgroup handles a row
+    const uint tid = gl_LocalInvocationIndex;
+    const uint i0 = gl_WorkGroupID.x;
+    const uint i1 = gl_WorkGroupID.y;
+    const uint i2 = gl_WorkGroupID.z % nem2;
+    const uint i3 = gl_WorkGroupID.z / nem2;
+
+    float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
+
+    uint result = 0;
+
+    // Fast path for fully in-bounds blocks where we can do f16vec4 loads
+    if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
+        ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
+        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+            float min_v = FLT_MAX_OVER_2;
+            float max_v = -FLT_MAX_OVER_2;
+            [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
+                uint j0 = (i + tid) % (Bc / 4);
+                uint j1 = (i + tid) / (Bc / 4);
+
+                j0 *= 4;
+                j0 += (i0 * 16 + block_x) * Bc;
+                j1 += i1 * Br;
+
+                vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
+                [[unroll]] for (int c = 0; c < 4; ++c) {
+                    min_v = min(min_v, f[c]);
+                    max_v = max(max_v, f[c]);
+                }
+            }
+            min_v = subgroupMin(min_v);
+            max_v = subgroupMax(max_v);
+            if (gl_SubgroupInvocationID == 0) {
+                minsh[gl_SubgroupID] = min_v;
+                maxsh[gl_SubgroupID] = max_v;
+            }
+            barrier();
+            if (tid == 0) {
+                [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
+                    min_v = min(min_v, minsh[i]);
+                    max_v = max(max_v, maxsh[i]);
+                }
+                if (max_v <= -FLT_MAX_OVER_2) {
+                    result |= 1 << (2*block_x);
+                }
+                if (min_v == 0.0f && max_v == 0.0f) {
+                    result |= 2 << (2*block_x);
+                }
+            }
+            barrier();
+        }
+    } else {
+        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+            float min_v = FLT_MAX_OVER_2;
+            float max_v = -FLT_MAX_OVER_2;
+            [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) {
+                if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) {
+                    continue;
+                }
+                uint j0 = (i + tid) % Bc;
+                uint j1 = (i + tid) / Bc;
+
+                j0 += (i0 * 16 + block_x) * Bc;
+                j1 += i1 * Br;
+
+                if (j0 < nem0 && j1 < nem1) {
+                    float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
+                    min_v = min(min_v, f);
+                    max_v = max(max_v, f);
+                }
+            }
+            min_v = subgroupMin(min_v);
+            max_v = subgroupMax(max_v);
+            if (gl_SubgroupInvocationID == 0) {
+                minsh[gl_SubgroupID] = min_v;
+                maxsh[gl_SubgroupID] = max_v;
+            }
+            barrier();
+            if (tid == 0) {
+                [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
+                    min_v = min(min_v, minsh[i]);
+                    max_v = max(max_v, maxsh[i]);
+                }
+                if (max_v <= -FLT_MAX_OVER_2) {
+                    result |= 1 << (2*block_x);
+                }
+                if (min_v == 0.0f && max_v == 0.0f) {
+                    result |= 2 << (2*block_x);
+                }
+            }
+            barrier();
+        }
+    }
+
+    if (tid == 0) {
+        data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result;
+    }
+}
index ca486a288a1907b71e2ecc143f4b2311dffb4832..42ebc21e2a6eb6df303d885bdccfd3c061509df0 100644 (file)
@@ -790,6 +790,8 @@ void process_shaders() {
     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
     string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
 
+    string_to_spv("fa_mask_opt", "flash_attn_mask_opt.comp", {});
+
     string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
     string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});
 
index cecdf470380f360d22d928865c09eba4752d3d75..fbe23037cc9418d39c6e20c0536bc1982c2fa741 100644 (file)
@@ -169,20 +169,22 @@ static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float m
     const int blck0 = 128;
     const int blck1 = 64;
 
-    // number of INF blocks
-    const int n_inf_blocks = 0.1*(ne0*ne1*ne2*ne3)/(blck0*blck1);
+    // number of INF/zero blocks
+    const int n_inf_zero_blocks = 0.2*(ne0*ne1*ne2*ne3)/(blck0*blck1);
 
-    for (int b = 0; b < n_inf_blocks; b++) {
+    for (int b = 0; b < n_inf_zero_blocks; b++) {
         const int p3 = (rd() % ne3);
         const int p2 = (rd() % ne2);
         const int p1 = (rd() % ne1);
         const int p0 = (rd() % ne0);
 
+        bool inf = rd() & 1;
+
         for (int i1 = 0; i1 < blck1 && p1 + i1 < ne1; i1++) {
             const int idx = p3*ne2*ne1*ne0 + p2*ne1*ne0 + (p1 + i1)*ne0 + p0;
 
             for (int i0 = 0; i0 < blck0 && p0 + i0 < ne0; i0++) {
-                data_f32[idx + i0] = -INFINITY;
+                data_f32[idx + i0] = inf ? -INFINITY : 0.0f;
             }
         }
     }