]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: support mixed/deepseekR1 FA head sizes (#14509)
authorJeff Bolz <redacted>
Thu, 3 Jul 2025 18:21:14 +0000 (13:21 -0500)
committerGitHub <redacted>
Thu, 3 Jul 2025 18:21:14 +0000 (20:21 +0200)
* vulkan: better parameterize FA by head sizes

* vulkan: support mixed/deepseekR1 FA head sizes

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index 8114cb42046e47bf201d24e3154b978eb7e7010e..c0032cba218cfbc414db0740cc03d52195ed6012 100644 (file)
@@ -224,6 +224,21 @@ enum vk_device_architecture {
     INTEL_XE2,
 };
 
+// HSK x HSV
+enum FaHeadSizes {
+    FA_HEAD_SIZE_64,
+    FA_HEAD_SIZE_80,
+    FA_HEAD_SIZE_96,
+    FA_HEAD_SIZE_112,
+    FA_HEAD_SIZE_128,
+    FA_HEAD_SIZE_192,
+    FA_HEAD_SIZE_192_128,
+    FA_HEAD_SIZE_256,
+    FA_HEAD_SIZE_576_512,
+    FA_HEAD_SIZE_UNSUPPORTED,
+    FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
+};
+
 static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
     vk::PhysicalDeviceProperties props = device.getProperties();
 
@@ -467,26 +482,11 @@ struct vk_device_struct {
     vk_pipeline pipeline_conv2d_dw_cwhn_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
-    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
-
-    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
-
-    vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
+
+    vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
+
+    vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
 
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
@@ -1003,7 +1003,7 @@ struct ggml_backend_vk_context {
 
     // number of additional consecutive nodes that are being fused with the
     // node currently being processed
-    uint32_t num_additional_fused_ops {};
+    int num_additional_fused_ops {};
 };
 
 static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000;  // NOLINT
@@ -1699,6 +1699,35 @@ enum FaCodePath {
     FA_COOPMAT2,
 };
 
+static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
+    if (hsk != 192 && hsk != 576 && hsk != hsv) {
+        return FA_HEAD_SIZE_UNSUPPORTED;
+    }
+    switch (hsk) {
+    case 64: return FA_HEAD_SIZE_64;
+    case 80: return FA_HEAD_SIZE_80;
+    case 96: return FA_HEAD_SIZE_96;
+    case 112: return FA_HEAD_SIZE_112;
+    case 128: return FA_HEAD_SIZE_128;
+    case 192:
+        if (hsv == 192) {
+            return FA_HEAD_SIZE_192;
+        } else if (hsv == 128) {
+            return FA_HEAD_SIZE_192_128;
+        } else {
+            return FA_HEAD_SIZE_UNSUPPORTED;
+        }
+    case 256: return FA_HEAD_SIZE_256;
+    case 576:
+        if (hsv == 512) {
+            return FA_HEAD_SIZE_576_512;
+        } else {
+            return FA_HEAD_SIZE_UNSUPPORTED;
+        }
+    default: return FA_HEAD_SIZE_UNSUPPORTED;
+    }
+}
+
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
@@ -1719,8 +1748,9 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
     }
 }
 
-static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
     GGML_UNUSED(clamp);
+    GGML_UNUSED(hsv);
 
     if (path == FA_SCALAR) {
         if (small_rows) {
@@ -1744,7 +1774,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
     }
 
     // small cols to reduce register count
-    if (ggml_is_quantized(type) || D == 256) {
+    if (ggml_is_quantized(type) || hsk >= 256) {
         return {64, 32};
     }
     return {64, 64};
@@ -2037,19 +2067,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
                                       parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
     };
 
-    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-        return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
+    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
     };
 
-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
         // For scalar, use 128 (arbitrary)
+        // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
+        const uint32_t D = (hsk|hsv);
         uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
                             ? scalar_flash_attention_workgroup_size
                             : ((small_rows && (D % 32) == 0) ? 256 : 128);
-        auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
+        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
 
         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
         // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -2058,26 +2090,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
         GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
-        return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
     };
 
-#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1,                                      true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1,                                      true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true),   1,                                      true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true),   fa_rows_cols(FAPATH,D,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true),   1,                                      true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true),   fa_rows_cols(FAPATH,D,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true),   1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true),   fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true),   1,                                            true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true),   fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1],  true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
 
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
 
     CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
     CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -3688,7 +3723,6 @@ static void ggml_vk_instance_init() {
 
     }
 
-    size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
     vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
 
     // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -6002,24 +6036,47 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 }
 
-static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
     // Needs to be kept up to date on shader changes
+    GGML_UNUSED(hsv);
     const uint32_t wg_size = scalar_flash_attention_workgroup_size;
     const uint32_t Br = scalar_flash_attention_num_large_rows;
     const uint32_t Bc = scalar_flash_attention_Bc;
 
+    const uint32_t tmpsh = wg_size * sizeof(float);
+    const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
+
+    const uint32_t masksh = Bc * Br * sizeof(float);
+
+    const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
+
+    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
+    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
+
+    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
+
+    return supported;
+}
+
+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
+    // Needs to be kept up to date on shader changes
+    GGML_UNUSED(hsv);
+    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
+    const uint32_t Br = coopmat1_flash_attention_num_large_rows;
+    const uint32_t Bc = scalar_flash_attention_Bc;
+
     const uint32_t acctype = f32acc ? 4 : 2;
     const uint32_t f16vec4 = 8;
 
     const uint32_t tmpsh = wg_size * sizeof(float);
     const uint32_t tmpshv4 = wg_size * 4 * acctype;
 
-    const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
+    const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
 
-    const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
+    const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
     const uint32_t sfsh = Bc * sfshstride * acctype;
 
-    const uint32_t kshstride = D / 4 + 2;
+    const uint32_t kshstride = hsk / 4 + 2;
     const uint32_t ksh = Bc * kshstride * f16vec4;
 
     const uint32_t slope = Br * sizeof(float);
@@ -6027,7 +6084,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
-    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
+    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
 
     return supported;
 }
@@ -6051,11 +6108,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
     const uint32_t nem2 = mask ? mask->ne[2] : 0;
 
-    const uint32_t D = neq0;
+    const uint32_t HSK = nek0;
+    const uint32_t HSV = nev0;
     uint32_t N = neq1;
     const uint32_t KV = nek1;
 
-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == HSV);
     GGML_ASSERT(ne2 == N);
 
     // input tensor rows must be contiguous
@@ -6063,12 +6121,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == HSK);
 
     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
 
     GGML_ASSERT(nev1 == nek1);
 
@@ -6089,7 +6144,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
                                              (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
 
-        const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
+        const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
 
         if (!coopmat_shape_supported || !coopmat_shmem_supported) {
             path = FA_SCALAR;
@@ -6142,47 +6197,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         path = FA_SCALAR;
     }
 
+    // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
+    if (path == FA_SCALAR &&
+        !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
+        small_rows = true;
+    }
+
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
 
+    FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
+
     switch (path) {
     case FA_SCALAR:
-        switch (D) {
-        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
-        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
-        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
-        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
-        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
-        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
-        default:
-            GGML_ASSERT(!"unsupported D value");
-            return;
-        }
+        pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
         break;
     case FA_COOPMAT1:
-        switch (D) {
-        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
-        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
-        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
-        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
-        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
-        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
-        default:
-            GGML_ASSERT(!"unsupported D value");
-            return;
-        }
+        pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
         break;
     case FA_COOPMAT2:
-        switch (D) {
-        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
-        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
-        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
-        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
-        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
-        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
-        default:
-            GGML_ASSERT(!"unsupported D value");
-            return;
-        }
+        pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
         break;
     default:
         GGML_ASSERT(0);
@@ -6212,7 +6245,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Try to use split_k when KV is large enough to be worth the overhead
     if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
-        split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
+        split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
@@ -6224,7 +6257,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
     // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
-    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
+    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
     if (split_k_size > ctx->device->max_memory_allocation_size) {
         GGML_ABORT("Requested preallocation size is too large");
     }
@@ -6342,7 +6375,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                     pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
+        const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
                                     {
                                         vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
@@ -10241,19 +10274,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
                 auto device = ggml_vk_get_device(ctx->device);
                 bool coopmat2 = device->coopmat2;
-                switch (op->src[0]->ne[0]) {
-                case 64:
-                case 80:
-                case 96:
-                case 112:
-                case 128:
-                case 256:
-                    break;
-                default:
-                    return false;
-                }
-                if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
-                    // different head sizes of K and V are not supported yet
+                FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
+                if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
                     return false;
                 }
                 if (op->src[0]->type != GGML_TYPE_F32) {
index 6f80101d1c4326f5f4632e738d2e3da9d3677c4c..788a5e065d1dba662cc2d7e2326803ffd720e487 100644 (file)
@@ -11,7 +11,8 @@
 #include "types.comp"
 #include "flash_attn_base.comp"
 
-const uint32_t D_per_thread = D / D_split;
+const uint32_t HSK_per_thread = HSK / D_split;
+const uint32_t HSV_per_thread = HSV / D_split;
 
 const uint32_t cols_per_iter = WorkGroupSize / D_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
 // Rows index by Q's dimension 2, and the first N rows are valid.
 D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
 {
-    uint32_t offset = (iq2 + r) * D + c;
+    uint32_t offset = (iq2 + r) * HSV + c;
     data_o[o_offset + offset] = D_TYPE(elem);
     return elem;
 }
@@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
 shared vec4 tmpshv4[WorkGroupSize];
 
 shared float masksh[Bc][Br];
-shared vec4 Qf[Br][D / 4];
+shared vec4 Qf[Br][HSK / 4];
 
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
@@ -53,18 +54,18 @@ void main() {
 
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
-    [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
-        uint32_t d = (idx + tid) % (D / 4);
-        uint32_t r = (idx + tid) / (D / 4);
-        if (r < Br && d < D / 4 &&
+    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
+        uint32_t d = (idx + tid) % (HSK / 4);
+        uint32_t r = (idx + tid) / (HSK / 4);
+        if (r < Br && d < HSK / 4 &&
             i * Br + r < N) {
             Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
         }
     }
     barrier();
 
-    vec4 Of[Br][D_per_thread / 4];
-    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+    vec4 Of[Br][HSV_per_thread / 4];
+    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             Of[r][d] = vec4(0.0);
         }
@@ -116,7 +117,7 @@ void main() {
 
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
 #if BLOCK_SIZE > 1
                 uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                 uint ib = coord / BLOCK_SIZE;
@@ -195,14 +196,14 @@ void main() {
             Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
         }
 
-        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
             [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                 Of[r][d] = eMf[r] * Of[r][d];
             }
         }
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 #if BLOCK_SIZE > 1
                 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                 uint ib = coord / BLOCK_SIZE;
@@ -259,7 +260,7 @@ void main() {
         Lf[r] = tmpsh[d_tid];
         barrier();
 
-        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 
             Of[r][d] = eMf * Of[r][d];
             tmpshv4[tid] = Of[r][d];
@@ -281,11 +282,11 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
+        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
 
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                         perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                     }
@@ -293,7 +294,7 @@ void main() {
             }
         }
 
-        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
+        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
                 perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -309,18 +310,18 @@ void main() {
         Lfrcp[r] = 1.0 / Lf[r];
     }
 
-    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             Of[r][d] *= Lfrcp[r];
         }
     }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
 
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                         perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                     }
@@ -330,9 +331,9 @@ void main() {
     } else {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (i * Br + r < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+                        data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                     }
                 }
             }
index 67b935cf57e50d5e50288e0e16339e92002afb4b..6609f0bad3d85099c8ec9e3ae14167ea72ecfa62 100644 (file)
@@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
 layout (constant_id = 1) const uint32_t Br = 1;
 layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t D = 32;
-layout (constant_id = 4) const uint32_t Clamp = 0;
-layout (constant_id = 5) const uint32_t D_split = 16;
-
+layout (constant_id = 3) const uint32_t HSK = 32;
+layout (constant_id = 4) const uint32_t HSV = 32;
+layout (constant_id = 5) const uint32_t Clamp = 0;
+layout (constant_id = 6) const uint32_t D_split = 16;
 
 layout (push_constant) uniform parameter {
     uint32_t N;
index 26fe50c7a81e58080d324c9d5e92512def9253b0..e74e2fa9346749660eed71e6412d2948b49b0345 100644 (file)
@@ -13,7 +13,9 @@
 #include "types.comp"
 #include "flash_attn_base.comp"
 
-const uint32_t D_per_thread = D / D_split;
+const uint32_t HSK_per_thread = HSK / D_split;
+const uint32_t HSV_per_thread = HSV / D_split;
+
 const uint32_t row_split = 4;
 const uint32_t rows_per_thread = Br / row_split;
 const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
@@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
 // Rows index by Q's dimension 2, and the first N rows are valid.
 D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
 {
-    uint32_t offset = (iq2 + r) * D + c;
+    uint32_t offset = (iq2 + r) * HSV + c;
     data_o[o_offset + offset] = D_TYPE(elem);
     return elem;
 }
@@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
 shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
 shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
 
-const uint32_t qstride = D / 4 + 2; // in units of f16vec4
+const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];
 
-// Avoid padding for D==256 to make it fit in 48KB shmem.
-const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
+// Avoid padding for hsk==256 to make it fit in 48KB shmem.
+const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
 shared ACC_TYPE sfsh[Bc * sfshstride];
 
-const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
+const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
 shared f16vec4 ksh[Bc * kshstride];
 
 shared float slope[Br];
@@ -74,18 +76,18 @@ void main() {
 
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
-    [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
-        uint32_t d = (idx + tid) % (D / 4);
-        uint32_t r = (idx + tid) / (D / 4);
-        if (r < Br && d < D / 4 &&
+    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
+        uint32_t d = (idx + tid) % (HSK / 4);
+        uint32_t r = (idx + tid) / (HSK / 4);
+        if (r < Br && d < HSK / 4 &&
             i * Br + r < N) {
             Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
         }
     }
     barrier();
 
-    ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
-    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+    ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
+    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             Of[r][d] = ACC_TYPEV4(0.0);
         }
@@ -131,10 +133,10 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
-            uint32_t d = (idx + tid) % (D / 4);
-            uint32_t c = (idx + tid) / (D / 4);
-            if (c < Bc && d < D / 4) {
+        [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
+            uint32_t d = (idx + tid) % (HSK / 4);
+            uint32_t c = (idx + tid) / (HSK / 4);
+            if (c < Bc && d < HSK / 4) {
 #if BLOCK_SIZE > 1
                 uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
                 uint ib = coord / BLOCK_SIZE;
@@ -149,14 +151,14 @@ void main() {
         }
         barrier();
 
-        // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
-        // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
+        // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
+        // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
         // This is written transposed in order to allow for N being 8 if implementations need it
         coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
         coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
         coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
 
-        for (uint32_t d = 0; d < D / 16; ++d) {
+        for (uint32_t d = 0; d < HSK / 16; ++d) {
             coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
 
             uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@@ -206,7 +208,7 @@ void main() {
             eMf[r] = exp(Moldf - Mf[r]);
         }
 
-        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
             [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                 Of[r][d] = float16_t(eMf[r]) * Of[r][d];
             }
@@ -221,7 +223,7 @@ void main() {
                 Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
                 Lf[r] += Pf[r];
             }
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 #if BLOCK_SIZE > 1
                 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                 uint ib = coord / BLOCK_SIZE;
@@ -284,7 +286,7 @@ void main() {
     }
 
     [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 
             Of[r][d] = float16_t(eMf[r]) * Of[r][d];
             tmpshv4[tid] = Of[r][d];
@@ -304,11 +306,11 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
+        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
 
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                         perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
                     }
@@ -316,7 +318,7 @@ void main() {
             }
         }
 
-        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
+        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (tile_row(r) < N) {
                 perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -332,18 +334,18 @@ void main() {
         Lfrcp[r] = 1.0 / Lf[r];
     }
 
-    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             Of[r][d] *= float16_t(Lfrcp[r]);
         }
     }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
 
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                         perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
                     }
@@ -353,9 +355,9 @@ void main() {
     } else {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (i * Br + tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+                        data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                     }
                 }
             }
index cf47bd53e28bb68fe0a4e12afa1ae35d24e378fe..8792d5195e45731ff2629d5ddd173066b49051ff 100644 (file)
@@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
 // Rows index by Q's dimension 2, and the first N rows are valid.
 D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
 {
-    if (r < N && c < D) {
-        uint32_t offset = (iq2 + r) * D + c;
+    if (r < N && c < HSV) {
+        uint32_t offset = (iq2 + r) * HSV + c;
         data_o[o_offset + offset] = D_TYPE(elem);
     }
     return elem;
@@ -86,9 +86,9 @@ void main() {
     tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
 #endif
 
-    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
-    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
-    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
+    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
+    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
+    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
 
     // hint to the compiler that strides are aligned for the aligned variant of the shader
     if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
@@ -104,16 +104,16 @@ void main() {
     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
     tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
 
-    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
-    coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
+    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
+    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
 
     uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
-    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
+    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
 
-    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
+    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
     Qf16 *= float16_t(p.scale);
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
 
     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
 
@@ -140,10 +140,10 @@ void main() {
 
         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
 
-        coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
+        coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
 
         uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
-        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
+        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
         S = coopMatMulAdd(Qf16, K_T, S);
 
         if (p.logit_softcap != 0.0f) {
@@ -208,42 +208,42 @@ void main() {
         rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
         rowsum = coopMatMulAdd(P_A, One, rowsum);
 
-        coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
+        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
         uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
-        coopMatLoadTensorNV(V,  data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
+        coopMatLoadTensorNV(V,  data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
 
         L = eM*L + rowsum;
 
         // This is the "diagonal" matrix in the paper, but since we do componentwise
         // multiply rather than matrix multiply it has the diagonal element smeared
         // across the row
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
 
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
         // multiply with fp16 accumulation, then add to O.
-        coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
         PV = coopMatMulAdd(P_A, V, PV);
 
-        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
+        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
 
-        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
+        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
 
-        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
+        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
         coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
         return;
     }
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
 
     // resize L by using smear/reduce
     coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
@@ -255,18 +255,18 @@ void main() {
 
     O = Ldiag*O;
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
 
-    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
+    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
     if (p.gqa_ratio > 1) {
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
     } else {
         tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
-        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
+        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
 
         // permute dimensions
         tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
 
-        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
+        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
     }
 }