vulkan: support fattn sinks (#15126)
author Jeff Bolz <redacted>
Thu, 7 Aug 2025 20:44:20 +0000 (15:44 -0500)
committer GitHub <redacted>
Thu, 7 Aug 2025 20:44:20 +0000 (22:44 +0200)
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
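
Attention sinks add a per-head logit that enters the softmax denominator but carries no value vector. Once the online-softmax pass over the KV sequence has produced a running row maximum M, running sum L, and unnormalized output accumulator O, folding in a sink logit s is a single rescaling step (descriptive notation, chosen to match the shader variables Mf/Lf/Of below):

    M'  = max(M, s)
    O'  = O * e^(M - M')
    L'  = L * e^(M - M') + e^(s - M')
    out = O' / L'

The hunks below thread the sinks tensor (src[4] of GGML_OP_FLASH_ATTN_EXT) through the host code as an extra descriptor binding, signal its presence via a new bit in the packed mask_n_head_log2 push constant, and apply this rescaling in the scalar, coopmat1, and coopmat2 shader paths as well as in the split_k reduction.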

index f1cb90e3b360eaa15dddda975a563cb70221ddae..b1cbbc9866c694ee1c6e65abebdd6f1f9c5c9488 100644 (file)
@@ -2286,14 +2286,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
     };
 
 #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true),   1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true),   fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true),   1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true),   fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true),   1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true),   fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true),   1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true),   fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
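
The only change across these lines is the descriptor count passed to ggml_vk_create_pipeline, 5 to 6, making room for the sinks buffer. The resulting layout, with bindings 4 and 5 taken from flash_attn_base.comp below and 0-3 presumed from the buffer order in the ggml_vk_flash_attn dispatch:

    // presumed flash-attention descriptor layout after this change:
    //   binding 0: Q     binding 1: K     binding 2: V     binding 3: mask
    //   binding 4: sinks (new, readonly)
    //   binding 5: output (previously binding 4)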
 
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
         CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
@@ -2910,7 +2910,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4],   "get_rows_mxfp4_f32",   get_rows_mxfp4_f32_len,   get_rows_mxfp4_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -6507,11 +6507,14 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     return supported;
 }
 
-static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
     std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
     std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    if (sinks) {
+        std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
+    }
     std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
 
     GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
@@ -6710,10 +6713,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
-    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
+    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr;
+    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0;
 
-    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
+    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false;
 
     if (ctx->device->uma) {
         ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
@@ -6728,6 +6731,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
             M_uma = d_M != nullptr;
         }
+        if (sinks) {
+            ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset);
+            S_uma = d_S != nullptr;
+        }
     }
 
 
@@ -6763,7 +6770,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
+    if (!S_uma) {
+        d_S = d_Q;
+        s_buf_offset = q_buf_offset;
+        if (sinks) {
+            ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context;
+            d_S = s_buf_ctx->dev_buffer;
+            s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs;
+        }
+    }
+
+    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
 
     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
@@ -6787,6 +6804,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                         vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                                        vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                                     },
                                     // We only use split_k when group query attention is enabled, which means
@@ -6796,10 +6814,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                     pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
+        const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
                                     {
                                         vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+                                        vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
                                     },
                                     pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
@@ -6810,6 +6829,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                         vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                                        vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
                                     },
                                     pc, { workgroups_x, workgroups_y, workgroups_z });
@@ -9874,7 +9894,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
 
     case GGML_OP_FLASH_ATTN_EXT:
-        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
+        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun);
 
         break;
 
@@ -10951,8 +10971,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
                     return false;
                 }
-                // TODO: support attention sinks [TAG_ATTN_SINKS]
-                if (op->src[4]) {
+                if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
                     return false;
                 }
                 if (op->src[0]->type != GGML_TYPE_F32) {
@@ -11547,6 +11566,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
+        if (src_clone[4]) {
+            ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
+        }
     } else if (tensor->op == GGML_OP_MUL_MAT) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
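
On the host side, the changes above do three things: ggml_vk_flash_attn takes the sinks tensor (node->src[4]) as a new parameter, resolves it to a device buffer d_S through the same UMA/non-UMA paths as the mask (defaulting to d_Q so the binding always points at a valid buffer when no sinks are present), and reports its presence to the shaders via bit 24 of the packed mask_n_head_log2 push constant; the split_k reduction, whose push constants are separate, gets its own sinks flag in pc2. The supports_op check now accepts src[4] when it is GGML_TYPE_F32, and the result-checking path mirrors the sinks via ggml_flash_attn_ext_add_sinks. A minimal sketch of the unpacking, using the bit masks flash_attn_base.comp defines below:

    #include <cstdint>

    // Decode of the packed push constant; masks match flash_attn_base.comp
    // (SINK_ENABLE_BIT = 1<<24, MASK_ENABLE_BIT = 1<<16, N_LOG2_MASK = 0xFFFF).
    struct fa_flags { bool has_sinks; bool has_mask; uint32_t n_head_log2; };

    static fa_flags decode_mask_n_head_log2(uint32_t packed) {
        return {
            (packed & (1u << 24)) != 0,  // sinks tensor present
            (packed & (1u << 16)) != 0,  // mask tensor present
             packed & 0xFFFFu,           // n_head_log2, used for ALiBi slopes
        };
    }
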
index 45c6e7736ace687e549d1fbd115e61bc54a4bb1b..d40848e15fe973ad644748d5b6902a065b086be1 100644 (file)
@@ -305,6 +305,27 @@ void main() {
         return;
     }
 
+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (sink > Mf[r]) {
+                ms = exp(Mf[r] - sink);
+
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                    Of[r][d] *= ms;
+                }
+            } else {
+                vs = exp(sink - Mf[r]);
+            }
+
+            Lf[r] = Lf[r]*ms + vs;
+        }
+    }
+
     float Lfrcp[Br];
     [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
         Lfrcp[r] = 1.0 / Lf[r];
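
A quick numeric check of the sink block above, with made-up logits: after the KV loop over row logits {0, ln 2}, Mf = ln 2 and Lf = 1.5; a sink of ln 4 exceeds Mf, so the accumulators are rescaled:

    Lf   = e^(0 - ln 2) + e^(ln 2 - ln 2)  = 0.5 + 1  = 1.5
    ms   = e^(ln 2 - ln 4) = 0.5,  vs = 1
    Lf'  = 1.5 * 0.5 + 1                   = 1.75
    // direct denominator over {0, ln 2, ln 4}, rebased to ln 4:
    //   e^(0 - ln 4) + e^(ln 2 - ln 4) + e^(0) = 0.25 + 0.5 + 1 = 1.75

Of is scaled by the same ms = 0.5, which rebases the unnormalized outputs to the new maximum before the final multiply by 1/Lf.
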
index 7defe72b403b5eb2dad19fa6dd127fb7c3de963f..b57c9dcfc4ee5db1265f857edff0cd78fbf2bc85 100644 (file)
@@ -50,10 +50,13 @@ layout (push_constant) uniform parameter {
     uint32_t k_num;
 } p;
 
+#define SINK_ENABLE_BIT (1<<24)
 #define MASK_ENABLE_BIT (1<<16)
 #define N_LOG2_MASK 0xFFFF
 
-layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
+layout (binding = 4) readonly buffer S {float data_s[];};
+
+layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
 
 #if defined(A_TYPE_PACKED16)
 #define BINDING_IDX_K 0
@@ -111,6 +114,14 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
     return ACC_TYPE(pow(base, ACC_TYPE(exph)));
 }
 
+// Load the sink value, indexed by Q's dimension 2.
+ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+    const uint32_t h = iq2 + (r % p.gqa_ratio);
+
+    return ACC_TYPE(data_s[h]);
+}
+
 uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
          iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
          q_stride, k_stride, v_stride, m_stride;
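
perElemOpGetSink selects one sink per Q head: under grouped-query attention the rows handled by a workgroup cycle through the p.gqa_ratio Q heads that share a KV head, hence h = iq2 + (r % p.gqa_ratio). Evaluated with illustrative numbers:

    // hypothetical values: p.gqa_ratio = 4, iq2 = 8
    // r: 0  1   2   3   4  5   6   7
    // h: 8  9  10  11   8  9  10  11    ->  sink = data_s[h]
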
index 486735fe8b0c97fbe9b26882a0b2ef8fa6c32c29..230e815f22c45c150088f359a965149f1559c5c2 100644 (file)
@@ -329,6 +329,27 @@ void main() {
         return;
     }
 
+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (sink > Mf[r]) {
+                ms = exp(Mf[r] - sink);
+
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                    Of[r][d] *= ACC_TYPE(ms);
+                }
+            } else {
+                vs = exp(sink - Mf[r]);
+            }
+
+            Lf[r] = Lf[r]*ms + vs;
+        }
+    }
+
     float Lfrcp[rows_per_thread];
     [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lfrcp[r] = 1.0 / Lf[r];
index 274f48fcabdd081abc14fa4c243cd60bcb7c44c9..b0564ca0bfc8372a766f848319b19605449e24a9 100644 (file)
@@ -248,6 +248,34 @@ void main() {
     // resize L by using smear/reduce
     coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
+        coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
+
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
+
+        // resize M by using smear/reduce
+        coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
+
+        // O, Ldiag, Mr all have the same type so all element locations match
+        [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
+            ACC_TYPE sink = S[i];
+
+            ACC_TYPE ms = ACC_TYPE(1.0f);
+            ACC_TYPE vs = ACC_TYPE(1.0f);
+
+            if (sink > Mr[i]) {
+                ms = exp(Mr[i] - sink);
+
+                O[i] *= ms;
+            } else {
+                vs = exp(sink - Mr[i]);
+            }
+
+            Ldiag[i] = Ldiag[i]*ms + vs;
+        }
+    }
+
     [[unroll]]
     for (int k = 0; k < Ldiag.length(); ++k) {
         Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
index 0a17a9df23f9fbd09baee3f85c1f836062ed81ef..76ef4b6dfb571c55cb7d17db3730b488ccb95911 100644 (file)
@@ -7,13 +7,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {float data_d[];};
+layout (binding = 1) readonly buffer B {float data_s[];};
+layout (binding = 2) writeonly buffer D {float data_d[];};
 
 layout (push_constant) uniform parameter {
     uint D;
     uint N;
     uint ne3;
     uint k_num;
+    uint sinks;
 } p;
 
 shared float tmpsh[BLOCK_SIZE];
@@ -73,6 +75,22 @@ void main() {
     }
     L = tmpsh[0];
 
+    float sink;
+    if (p.sinks != 0) {
+        sink = data_s[n];
+
+        float ms = 1.0f;
+        float vs = 1.0f;
+
+        if (sink > m_max) {
+            ms = exp(m_max - sink);
+        } else {
+            vs = exp(sink - m_max);
+        }
+
+        L = L*ms + vs;
+    }
+
     L = 1.0 / L;
 
     // D dimension is split across workgroups in the y dimension
@@ -85,6 +103,13 @@ void main() {
             float m = data_a[m_offset + k * lm_stride];
             O += exp(m - m_max) * data_a[o_offset];
         }
+        if (p.sinks != 0) {
+            if (sink > m_max) {
+                float ms = 1.0f;
+                ms = exp(m_max - sink);
+                O *= ms;
+            }
+        }
         O *= L;
         data_d[iq3 * D * N + D * n + d] = O;
     }