uint32_t subgroup_size;
uint32_t shader_core_count;
bool uma;
+ bool coopmat2;
size_t idx;
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
vk_pipeline pipeline_matmul_split_k_reduce;
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
vk_matmul_pipeline pipeline_matmul_id_f32;
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_pool2d_f32;
+ // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
+ vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
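+ // e.g. [GGML_TYPE_F16][0][1][1] selects the f16-accumulator, small-rows, aligned variant
+ // for f16 K/V at the given head size D.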
+
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
uint32_t nei0; uint32_t ne11;
};
+struct vk_flash_attn_push_constants {
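+ // Field order/layout must match the push_constant block in the cm2 flash attention shader.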
+ uint32_t N;
+ uint32_t KV;
+
+ uint32_t ne1;
+ uint32_t ne2;
+ uint32_t ne3;
+
+ uint32_t neq2;
+ uint32_t neq3;
+ uint32_t nek2;
+ uint32_t nek3;
+ uint32_t nev2;
+ uint32_t nev3;
+ uint32_t nem1;
+
+ uint32_t nb02;
+ uint32_t nb03;
+ uint32_t nb12;
+ uint32_t nb13;
+ uint32_t nb22;
+ uint32_t nb23;
+ uint32_t nb31;
+
+ float scale;
+ float max_bias;
+ float logit_softcap;
+
+ uint32_t mask;
+ uint32_t n_head_log2;
+ float m0;
+ float m1;
+};
+
struct vk_op_push_constants {
uint32_t KX;
uint32_t KY;
);
}
+// number of rows/cols for flash attention shader
+static constexpr uint32_t flash_attention_num_small_rows = 32;
+static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+ GGML_UNUSED(clamp);
+
+ // small rows, large cols
+ if (small_rows) {
+ return {flash_attention_num_small_rows, 128};
+ }
+ // small cols to reduce register count
+ if (ggml_is_quantized(type) || D == 256) {
+ return {64, 32};
+ }
+ return {64, 64};
+}
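+// For example: f16 K/V at D=128 with large rows gives {64, 64} (rows, cols), any quantized
+// K/V type or D=256 gives {64, 32}, and the small-rows path always gives {32, 128}.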
+
+
static void ggml_vk_load_shaders(vk_device& device) {
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
// mulmat
std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
- l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
+ l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
+ l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
+ l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
- l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms;
- uint32_t l_align, m_align, s_align;
+ l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
+ l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
+ l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
- l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
- m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
- s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
-
- l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
- m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
- s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
-
- l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
- m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
- s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
-
- l_align = 128;
- m_align = 64;
- s_align = 32;
-
- // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
- // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
- // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
- // But the numbers happen to work out for 32KB shared memory size that when using the medium
- // size there's enough room for everything, and we assert for this.
- uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
- if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
- l_warptile = m_warptile;
- l_wg_denoms = m_wg_denoms;
- shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
- GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
- }
- if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
- // assert mul_mat_mat_id shaders will fit.
- GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
- }
-
- shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
- if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
- if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
- l_warptile_mmq = m_warptile_mmq;
- l_mmq_wg_denoms = m_mmq_wg_denoms;
- } else {
- l_warptile_mmq = s_warptile_mmq;
- l_mmq_wg_denoms = s_mmq_wg_denoms;
+ uint32_t l_align, m_align, s_align;
+ if (device->coopmat2) {
+ // spec constants and tile sizes for non-quant matmul/matmul_id
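+ // (each warptile here is { workgroup size, BM, BN, BK }, matching the cm2 matmul shader's
+ // spec constants, and each wg_denoms is { BM, BN, 1 })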
+ l_warptile = { 256, 128, 256, 64 };
+ m_warptile = { 256, 128, 128, 64 };
+ s_warptile = { 128, 32, 16, 64 };
+ l_wg_denoms = {128, 256, 1 };
+ m_wg_denoms = {128, 128, 1 };
+ s_wg_denoms = { 32, 16, 1 };
+
+ // spec constants and tile sizes for quant matmul (non-Qi_K)
+ l_warptile_mmq = { 256, 128, 256, 64 };
+ m_warptile_mmq = { 256, 128, 128, 64 };
+ s_warptile_mmq = { 256, 128, 128, 64 };
+ l_mmq_wg_denoms = { 128, 256, 1 };
+ m_mmq_wg_denoms = { 128, 128, 1 };
+ s_mmq_wg_denoms = { 128, 128, 1 };
+
+ // spec constants and tile sizes for quant matmul (Qi_K)
+ l_warptile_mmq_k = { 256, 128, 512, 16 };
+ m_warptile_mmq_k = { 256, 128, 256, 16 };
+ s_warptile_mmq_k = { 256, 32, 128, 64 };
+ l_mmq_wg_denoms_k = { 128, 512, 1 };
+ m_mmq_wg_denoms_k = { 128, 256, 1 };
+ s_mmq_wg_denoms_k = { 32, 128, 1 };
+
+ // spec constants and tile sizes for quant matmul_id
+ l_warptile_mmqid = { 256, 128, 128, 16 };
+ m_warptile_mmqid = { 256, 128, 64, 16 };
+ s_warptile_mmqid = { 256, 64, 64, 16 };
+ l_mmqid_wg_denoms = { 128, 128, 1 };
+ m_mmqid_wg_denoms = { 128, 64, 1 };
+ s_mmqid_wg_denoms = { 64, 64, 1 };
+
+ l_align = 128;
+ m_align = 64;
+ s_align = 32;
+ } else {
+ l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
+ m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
+ l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
+ m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
+ s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
+ l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
+ m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
+ s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
+ l_align = 128;
+ m_align = 64;
+ s_align = 32;
+
+ // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
+ // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
+ // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
+ // But the numbers happen to work out so that with a 32KB shared memory size, the medium
+ // tile size leaves enough room for everything, and we assert for this.
+ uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
+ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
+ l_warptile = m_warptile;
+ l_wg_denoms = m_wg_denoms;
+ shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
+ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
+ }
+ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
+ // assert mul_mat_mat_id shaders will fit.
+ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
}
+
shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
- GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
- }
- if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
- // assert mul_mat_mat_id shaders will fit.
- GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
+ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
+ if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
+ l_warptile_mmq = m_warptile_mmq;
+ l_mmq_wg_denoms = m_mmq_wg_denoms;
+ } else {
+ l_warptile_mmq = s_warptile_mmq;
+ l_mmq_wg_denoms = s_mmq_wg_denoms;
+ }
+ shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
+ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
+ }
+ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
+ // assert mul_mat_mat_id shaders will fit.
+ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
+ }
}
device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
};
+#if defined(VK_NV_cooperative_matrix2)
+ if (device->coopmat2) {
+
+ auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+ return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
+ };
+
+ auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+ // For a large number of rows, 128 invocations seems to work best.
+ // For a small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
+ // can't use 256 for D==80.
+ uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
+ auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
+ return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
+ };
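+ // These land in the cm2 flash attention shader as { local_size_x, Br, Bc, D, Clamp };
+ // clamp is nonzero only for the unaligned variants, which must mask out-of-bounds rows/cols.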
+
+#define CREATE_FA2(TYPE, NAMELC, D) \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
+
+#define CREATE_FA(TYPE, NAMELC) \
+ CREATE_FA2(TYPE, NAMELC, 64) \
+ CREATE_FA2(TYPE, NAMELC, 80) \
+ CREATE_FA2(TYPE, NAMELC, 96) \
+ CREATE_FA2(TYPE, NAMELC, 112) \
+ CREATE_FA2(TYPE, NAMELC, 128) \
+ CREATE_FA2(TYPE, NAMELC, 256)
+
+ CREATE_FA(GGML_TYPE_F16, f16)
+ CREATE_FA(GGML_TYPE_Q4_0, q4_0)
+ CREATE_FA(GGML_TYPE_Q4_1, q4_1)
+ CREATE_FA(GGML_TYPE_Q5_0, q5_0)
+ CREATE_FA(GGML_TYPE_Q5_1, q5_1)
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0)
+ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
+ //CREATE_FA(GGML_TYPE_Q2_K, q2_k)
+ //CREATE_FA(GGML_TYPE_Q3_K, q3_k)
+ //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
+ //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
+ //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
+ CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
+#undef CREATE_FA
+
+ // Create 6 variants, {s,m,l}x{unaligned,aligned}
+#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+
+ // Create 2 variants, {f16,f32} accumulator
+#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
+ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
+
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+
+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+#undef CREATE_MM
+#undef CREATE_MM2
+ } else
+#endif
if (device->fp16) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
device->physical_device = physical_devices[dev_num];
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
+ bool fp16_storage = false;
+ bool fp16_compute = false;
bool maintenance4_support = false;
bool sm_builtins = false;
+ bool pipeline_robustness = false;
+ bool coopmat2_support = false;
// Check if maintenance4 is supported
for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
maintenance4_support = true;
+ } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
+ fp16_storage = true;
+ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
+ fp16_compute = true;
} else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
sm_builtins = true;
+ } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
+ pipeline_robustness = true;
+ } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
+ !getenv("GGML_VULKAN_DISABLE_COOPMAT2")) {
+ coopmat2_support = true;
}
}
last_struct = (VkBaseOutStructure *)&sm_props;
}
+#if defined(VK_NV_cooperative_matrix2)
+ vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
+ if (coopmat2_support) {
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;
+ last_struct = (VkBaseOutStructure *)&coopmat2_props;
+ }
+#endif
+
device->physical_device.getProperties2(&props2);
device->properties = props2.properties;
device->shader_core_count = 0;
}
- bool fp16_storage = false;
- bool fp16_compute = false;
- bool pipeline_robustness = false;
-
- for (const auto& properties : ext_props) {
- if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
- fp16_storage = true;
- } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
- fp16_compute = true;
- } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
- pipeline_robustness = true;
- }
- }
-
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
vk11_features.pNext = &vk12_features;
+ last_struct = (VkBaseOutStructure *)&vk12_features;
+
VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
pl_robustness_features.pNext = nullptr;
pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
pl_robustness_features.pipelineRobustness = VK_FALSE;
if (pipeline_robustness) {
- vk12_features.pNext = &pl_robustness_features;
+ last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
+ last_struct = (VkBaseOutStructure *)&pl_robustness_features;
device_extensions.push_back("VK_EXT_pipeline_robustness");
}
+#if defined(VK_NV_cooperative_matrix2)
+ VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
+ coopmat2_features.pNext = nullptr;
+ coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
+ if (coopmat2_support) {
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
+ last_struct = (VkBaseOutStructure *)&coopmat2_features;
+ device_extensions.push_back("VK_NV_cooperative_matrix2");
+ }
+#endif
+
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
+ if (coopmat2_support) {
+#if defined(VK_NV_cooperative_matrix2)
+ if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
+ coopmat2_features.cooperativeMatrixFlexibleDimensions &&
+ coopmat2_features.cooperativeMatrixReductions &&
+ coopmat2_features.cooperativeMatrixConversions &&
+ coopmat2_features.cooperativeMatrixPerElementOperations &&
+ coopmat2_features.cooperativeMatrixTensorAddressing &&
+ coopmat2_features.cooperativeMatrixBlockLoads &&
+ vk12_features.bufferDeviceAddress) {
+
+ std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;
+ uint32_t count = 0;
+
+ PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =
+ (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)
+ vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV");
+
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);
+
+ VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};
+ empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;
+ flexible_dimensions.resize(count, empty_prop);
+
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());
+
+ bool found_fp16_128 = false,
+ found_fp16_256 = false,
+ found_fp32_128 = false,
+ found_fp32_256 = false;
+ // We need fp16*fp16 with an fp16 or fp32 accumulator, for workgroup size 128 with
+ // 32x16x16 (MxNxK) granularity and for workgroup size 256 with 32x32x16.
+ for (auto &prop : flexible_dimensions) {
+ if (prop.saturatingAccumulation == VK_FALSE &&
+ prop.scope == VK_SCOPE_WORKGROUP_KHR &&
+ prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
+ prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
+
+ if (prop.workgroupInvocations == 128 &&
+ prop.MGranularity <= 32 &&
+ prop.NGranularity <= 16 &&
+ prop.KGranularity <= 16) {
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
+ found_fp16_128 = true;
+ }
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
+ found_fp32_128 = true;
+ }
+ }
+ if (prop.workgroupInvocations == 256 &&
+ prop.MGranularity <= 32 &&
+ prop.NGranularity <= 32 &&
+ prop.KGranularity <= 16) {
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
+ found_fp16_256 = true;
+ }
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
+ found_fp32_256 = true;
+ }
+ }
+ }
+ }
+ if (found_fp16_128 && found_fp16_256 &&
+ found_fp32_128 && found_fp32_256 &&
+ coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
+ device->coopmat2 = true;
+ }
+ }
+#endif
+ }
+
if (!vk11_features.storageBuffer16BitAccess) {
std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
throw std::runtime_error("Unsupported device");
return ctx->device->pipeline_dequant[type];
}
-static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
+static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
return ctx->device->pipeline_matmul_f32;
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
return ctx->device->pipeline_matmul_f32_f16;
}
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
- return ctx->device->pipeline_matmul_f16_f32.f32acc;
- }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
- return ctx->device->pipeline_matmul_f16.f32acc;
+ if (prec == GGML_PREC_DEFAULT && ctx->device->coopmat2) {
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_matmul_f16_f32.f16acc;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_matmul_f16.f16acc;
+ }
+ } else {
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_matmul_f16_f32.f32acc;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_matmul_f16.f32acc;
+ }
}
- if (src1_type != GGML_TYPE_F32) {
+ if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
return nullptr;
}
return nullptr;
}
+ if (ctx->device->coopmat2) {
+ assert(src1_type == GGML_TYPE_F16);
+ return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
+ }
return ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
}
break;
}
+ if (ctx->device->coopmat2) {
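+ // Pick the largest tile whose workgroup denominators evenly divide m and n; otherwise
+ // fall back to the small tile.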
+ if ((m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) {
+ return aligned ? mmp->a_l : mmp->l;
+ }
+ if ((m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) {
+ return aligned ? mmp->a_m : mmp->m;
+ }
+ return aligned ? mmp->a_s : mmp->s;
+ }
+
if (m <= 32 || n <= 32) {
return aligned ? mmp->a_s : mmp->s;
}
}
const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
- const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
+ // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
+ const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+ !ggml_vk_dim01_contiguous(src1);
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
if (qx_needs_dequant) {
// Fall back to dequant + f16 mulmat
- mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
}
// Not implemented
}
}
+static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
+ VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
+ std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
+ std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+ std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
+
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ const uint32_t nem1 = mask ? mask->ne[1] : 0;
+ const uint32_t nbm1 = mask ? mask->nb[1] : 0;
+
+ const uint32_t D = neq0;
+ const uint32_t N = neq1;
+ const uint32_t KV = nek1;
+
+ GGML_ASSERT(ne0 == D);
+ GGML_ASSERT(ne2 == N);
+
+ // input tensor rows must be contiguous
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+ GGML_ASSERT(neq0 == D);
+ GGML_ASSERT(nek0 == D);
+ GGML_ASSERT(nev0 == D);
+
+ GGML_ASSERT(neq1 == N);
+ GGML_ASSERT(nev0 == D);
+
+ GGML_ASSERT(nev1 == nek1);
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ assert(dst->type == GGML_TYPE_F32);
+ assert(q->type == GGML_TYPE_F32);
+ assert(k->type == v->type);
+
+ vk_pipeline *pipelines;
+ // XXX TODO other backends may be changing accumulator precision to default to f32 soon
+ bool f32acc = dst->op_params[3] == GGML_PREC_F32;
+ bool small_rows = N <= flash_attention_num_small_rows;
+ switch (D) {
+ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
+ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
+ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
+ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
+ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
+ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
+ default:
+ assert(!"unsupported D value");
+ return;
+ }
+ assert(pipelines);
+
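+ // pipelines points at the {unaligned, aligned} pair; use the aligned variant when KV is a
+ // multiple of its alignment, which was set to the column tile size (Bc) at creation.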
+ bool aligned = (KV % pipelines[1]->align) == 0;
+ vk_pipeline pipeline = pipelines[aligned];
+ assert(pipeline);
+
+ if (dryrun) {
+ // Request descriptor sets
+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+ return;
+ }
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+ float logit_softcap = 0.0f;
+
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+ memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
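+ // When softcapping, fold the division into scale: the shader applies logit_softcap * tanh(S)
+ // after the Q*K^T multiply, giving logit_softcap * tanh(scale * Q*K^T / logit_softcap) overall.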
+ if (logit_softcap != 0) {
+ scale /= logit_softcap;
+ }
+
+ const uint32_t n_head_kv = neq2;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
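+ // Per-head ALiBi slopes, applied in the shader: head h uses m0^(h+1) when h < n_head_log2,
+ // otherwise m1^(2*(h - n_head_log2) + 1).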
+
+ ggml_vk_sync_buffers(subctx);
+
+ vk_buffer d_Q, d_K, d_V, d_D, d_M;
+ uint64_t q_buf_offset, k_buf_offset, v_buf_offset, d_buf_offset, m_buf_offset;
+
+ bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
+
+ if (ctx->device->uma) {
+ ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
+ ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
+ ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
+ ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
+ Q_uma = d_Q != nullptr;
+ K_uma = d_K != nullptr;
+ V_uma = d_V != nullptr;
+ D_uma = d_D != nullptr;
+ if (mask) {
+ ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
+ M_uma = d_M != nullptr;
+ }
+ }
+
+
+ ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+ ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context;
+ ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
+ ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
+
+ if (!Q_uma) {
+ d_Q = q_buf_ctx->dev_buffer;
+ q_buf_offset = vk_tensor_offset(q) + q->view_offs;
+ }
+ if (!K_uma) {
+ d_K = k_buf_ctx->dev_buffer;
+ k_buf_offset = vk_tensor_offset(k) + k->view_offs;
+ }
+ if (!V_uma) {
+ d_V = v_buf_ctx->dev_buffer;
+ v_buf_offset = vk_tensor_offset(v) + v->view_offs;
+ }
+ if (!D_uma) {
+ d_D = d_buf_ctx->dev_buffer;
+ d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
+ }
+
+ if (!M_uma) {
+ d_M = d_Q;
+ m_buf_offset = q_buf_offset;
+ if (mask) {
+ ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context;
+ d_M = m_buf_ctx->dev_buffer;
+ m_buf_offset = vk_tensor_offset(mask) + mask->view_offs;
+ }
+ }
+
+ const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+ {
+ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
+ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
+ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
+ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
+ },
+ sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
+}
+
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
switch (op) {
case GGML_OP_GET_ROWS:
ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ ggml_vk_ctx_begin(ctx->device, subctx);
for (size_t i = 0; i < num_it; i++) {
- ggml_vk_ctx_begin(ctx->device, subctx);
ggml_vk_matmul(
ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1
);
- ggml_vk_ctx_end(subctx);
}
+ ggml_vk_ctx_end(subctx);
auto begin = std::chrono::high_resolution_clock::now();
ggml_vk_submit(subctx, ctx->fence);
ggml_vk_buffer_write(y_buf, 0, y, y_sz);
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ ggml_vk_ctx_begin(ctx->device, subctx);
for (size_t i = 0; i < num_it; i++) {
- ggml_vk_ctx_begin(ctx->device, subctx);
ggml_vk_matmul(
ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1
);
- ggml_vk_ctx_end(subctx);
}
+ ggml_vk_ctx_end(subctx);
auto begin = std::chrono::high_resolution_clock::now();
4096, 512, 11008,
32000, 512, 4096,
};
- const size_t num_it = 1;
+ const size_t num_it = 100;
+
for (size_t i = 0; i < vals.size(); i += 3) {
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
const ggml_tensor * src0 = node->src[0];
const ggml_tensor * src1 = node->src[1];
const ggml_tensor * src2 = node->src[2];
+ const ggml_tensor * src3 = node->src[3];
switch (node->op) {
// Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_LEAKY_RELU:
+ case GGML_OP_FLASH_ATTN_EXT:
break;
default:
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
case GGML_OP_MUL_MAT_ID:
ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+ break;
+
+ case GGML_OP_FLASH_ATTN_EXT:
+ ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
+
break;
default:
return false;
break;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
+ case GGML_OP_FLASH_ATTN_EXT:
buf = tensor->buffer;
break;
return true;
} break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+ if (!ggml_vk_get_device(ctx->device)->coopmat2) {
+ return false;
+ }
+ switch (op->src[0]->ne[0]) {
+ case 64:
+ case 80:
+ case 96:
+ case 112:
+ case 128:
+ case 256:
+ break;
+ default:
+ return false;
+ }
+ if (op->src[0]->type != GGML_TYPE_F32) {
+ return false;
+ }
+ if (op->type != GGML_TYPE_F32) {
+ return false;
+ }
+ if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
+ return false;
+ }
+ // It's straightforward to support different K/V dequant, but would
+ // significantly increase the number of pipelines
+ if (op->src[1]->type != op->src[2]->type) {
+ return false;
+ }
+ switch (op->src[1]->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
+ //case GGML_TYPE_Q2_K:
+ //case GGML_TYPE_Q3_K:
+ //case GGML_TYPE_Q4_K:
+ //case GGML_TYPE_Q5_K:
+ //case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ4_NL:
+ break;
+ default:
+ return false;
+ }
+ return true;
+ }
case GGML_OP_GET_ROWS:
{
switch (op->src[0]->type) {
ggml_tensor * src0 = tensor->src[0];
ggml_tensor * src1 = tensor->src[1];
ggml_tensor * src2 = tensor->src[2];
+ ggml_tensor * src3 = tensor->src[3];
struct ggml_init_params iparams = {
/*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
struct ggml_tensor * src0_clone = nullptr;
struct ggml_tensor * src1_clone = nullptr;
struct ggml_tensor * src2_clone = nullptr;
+ struct ggml_tensor * src3_clone = nullptr;
struct ggml_tensor * tensor_clone = nullptr;
size_t src0_size;
size_t src1_size;
size_t src2_size;
+ size_t src3_size;
void * src0_buffer = nullptr;
void * src1_buffer = nullptr;
void * src2_buffer = nullptr;
+ void * src3_buffer = nullptr;
if (src0 != nullptr) {
src0_clone = ggml_dup_tensor(ggml_ctx, src0);
ggml_vk_print_tensor(src2, "src2");
}
}
+ if (src3 != nullptr) {
+ src3_clone = ggml_dup_tensor(ggml_ctx, src3);
+
+ src3_size = ggml_nbytes(src3);
+
+ src3_buffer = malloc(src3_size);
+ src3_clone->data = src3_buffer;
+ if (ggml_backend_buffer_is_host(src3->buffer)) {
+ memcpy(src3_clone->data, src3->data, src3_size);
+ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
+ } else if (ggml_backend_buffer_is_vk(src3->buffer)) {
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context;
+ vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
+ uint64_t offset = vk_tensor_offset(src3) + src3->view_offs;
+ if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) {
+ for (int i3 = 0; i3 < src3->ne[3]; i3++) {
+ for (int i2 = 0; i2 < src3->ne[2]; i2++) {
+ const int idx = i3*src3->ne[2] + i2;
+ ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]);
+ }
+ }
+
+ src3_clone->nb[0] = src3->nb[0];
+ src3_clone->nb[1] = src3->nb[1];
+ for (int i = 2; i < GGML_MAX_DIMS; i++) {
+ src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1];
+ }
+ } else {
+ if (offset + src3_size >= buffer_gpu->size) {
+ src3_size = buffer_gpu->size - offset;
+ }
+ ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size);
+ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
+ }
+ } else {
+ GGML_ABORT("fatal error");
+ }
+
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
+ ggml_vk_print_tensor(src3, "src3");
+ }
+ }
- if (tensor->op == GGML_OP_MUL_MAT) {
+ if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
+ const float *params = (const float *)tensor->op_params;
+ tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]);
+ } else if (tensor->op == GGML_OP_MUL_MAT) {
tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
} else if (tensor->op == GGML_OP_MUL_MAT_ID) {
tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
--- /dev/null
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_KHR_cooperative_matrix : enable
+#extension GL_NV_cooperative_matrix2 : enable
+#extension GL_EXT_buffer_reference : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
+#extension GL_KHR_shader_subgroup_vote : enable
+#extension GL_EXT_null_initializer : enable
+
+#include "types.comp"
+#include "dequant_funcs_cm2.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (constant_id = 1) const uint32_t Br = 32;
+layout (constant_id = 2) const uint32_t Bc = 32;
+layout (constant_id = 3) const uint32_t D = 32;
+layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
+
+layout (push_constant) uniform parameter {
+ uint32_t N;
+ uint32_t KV;
+
+ uint32_t ne1;
+ uint32_t ne2;
+ uint32_t ne3;
+
+ uint32_t neq2;
+ uint32_t neq3;
+ uint32_t nek2;
+ uint32_t nek3;
+ uint32_t nev2;
+ uint32_t nev3;
+ uint32_t nem1;
+
+ uint32_t nb02;
+ uint32_t nb03;
+ uint32_t nb12;
+ uint32_t nb13;
+ uint32_t nb22;
+ uint32_t nb23;
+ uint32_t nb31;
+
+ float scale;
+ float max_bias;
+ float logit_softcap;
+
+ uint32_t mask;
+ uint32_t n_head_log2;
+ float m0;
+ float m1;
+} p;
+
+layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
+layout (binding = 1) readonly buffer K {uint8_t data_k[];};
+layout (binding = 2) readonly buffer V {uint8_t data_v[];};
+layout (binding = 3) readonly buffer M {uint8_t data_m[];};
+layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
+ return max(x, y);
+}
+
+ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
+ return x;
+}
+
+// Replace matrix elements whose row is >= numRows or whose column is >= numCols with 'replace'
+ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
+ if (row >= numRows || col >= numCols) {
+ return replace;
+ }
+ return elem;
+}
+
+ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
+{
+ return exp(elem);
+}
+
+ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
+{
+ return max(elem0, elem1);
+}
+
+#if defined(BLOCK_SIZE)
+#define DECODEFUNC , DEQUANTFUNC
+#else
+#define DECODEFUNC
+#endif
+
+void main() {
+#if defined(DATA_A_IQ4_NL)
+ init_iq4nl_shmem();
+#endif
+
+ const uint32_t N = p.N;
+ const uint32_t KV = p.KV;
+
+ const uint32_t Tr = CEIL_DIV(N, Br);
+ const uint32_t Tc = CEIL_DIV(KV, Bc);
+
+ const uint32_t i = gl_WorkGroupID.x;
+
+ const uint32_t iq2 = gl_WorkGroupID.y;
+ const uint32_t iq3 = gl_WorkGroupID.z;
+
+ // broadcast factors
+ const uint32_t rk2 = p.neq2/p.nek2;
+ const uint32_t rk3 = p.neq3/p.nek3;
+
+ const uint32_t rv2 = p.neq2/p.nev2;
+ const uint32_t rv3 = p.neq3/p.nev3;
+
+ // k indices
+ const uint32_t ik3 = iq3 / rk3;
+ const uint32_t ik2 = iq2 / rk2;
+
+ // v indices
+ const uint32_t iv3 = iq3 / rv3;
+ const uint32_t iv2 = iq2 / rv2;
+
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
+ tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);
+
+ tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
+
+#if defined(BLOCK_SIZE)
+ tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
+ tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
+#endif
+
+ tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
+ tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
+ tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
+
+ coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
+
+ uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
+ coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
+
+ Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
+ Qf16 *= float16_t(p.scale);
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
+
+ L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
+ M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
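+ // L accumulates the running softmax row sums and M the running row maxima, updated online per KV tile.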
+
+ ACC_TYPE slope = ACC_TYPE(1.0);
+
+ // ALiBi
+ if (p.max_bias > 0.0f) {
+ const uint32_t h = iq2;
+
+ const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
+ const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+
+ slope = pow(base, ACC_TYPE(exph));
+ }
+
+ [[dont_unroll]]
+ for (uint32_t j = 0; j < Tc; ++j) {
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
+
+ coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
+
+ uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
+ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
+ S = coopMatMulAdd(Qf16, K_T, S);
+
+ if (p.logit_softcap != 0.0f) {
+ [[unroll]]
+ for (int k = 0; k < S.length(); ++k) {
+ S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
+ }
+ }
+
+ if (p.mask != 0) {
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+
+ coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+
+ S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+ }
+
+ // Clear padding elements to -inf, so they don't contribute to rowmax
+ if (Clamp != 0 &&
+ ((j + 1) * Bc > KV ||
+ (i + 1) * Br > N)) {
+
+ uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
+ uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
+
+ coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C);
+ }
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
+
+ coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;
+
+ // M = max(rowmax, Mold)
+ // P = e^(S - M)
+ // eM = e^(Mold - M)
+ coopMatPerElementNV(M, rowmax, Max, Mold);
+ coopMatPerElementNV(P, S - M, Exp);
+ coopMatPerElementNV(eM, Mold - M, Exp);
+
+ // Clear padding elements to 0, so they don't contribute to rowsum
+ if (Clamp != 0 &&
+ ((j + 1) * Bc > KV ||
+ (i + 1) * Br > N)) {
+
+ uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
+ uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
+
+ coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
+ }
+
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);
+
+ // compute rowsum by multiplying by matrix of all ones.
+ coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);
+
+ rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
+ rowsum = coopMatMulAdd(P_A, One, rowsum);
+
+ coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
+ uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
+ coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
+
+ L = eM*L + rowsum;
+
+ // This is the "diagonal" matrix in the paper, but since we do componentwise
+ // multiply rather than matrix multiply it has the diagonal element smeared
+ // across the row
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
+
+ // resize eM by using smear/reduce
+ coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
+
+ O = eMdiag * O;
+
+ O = coopMatMulAdd(P_A, V, O);
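+ // i.e. the online-softmax update: O = diag(eM) * O_old + P * V, with L = eM*L_old + rowsum above.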
+ }
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
+
+ // resize L by using smear/reduce
+ coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
+
+ [[unroll]]
+ for (int k = 0; k < Ldiag.length(); ++k) {
+ Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
+ }
+
+ O = Ldiag*O;
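+ // Normalize by the softmax row sums to produce the final output rows.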
+
+ tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
+
+ // permute dimensions
+ tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
+ uint32_t o_offset = iq3*p.ne2*p.ne1;
+
+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
+ coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
+}
--- /dev/null
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_KHR_cooperative_matrix : enable
+#extension GL_NV_cooperative_matrix2 : enable
+#extension GL_EXT_buffer_reference : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
+#extension GL_KHR_shader_subgroup_vote : enable
+
+#include "types.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (constant_id = 1) const uint BM = 64;
+layout (constant_id = 2) const uint BN = 64;
+layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
+
+layout (push_constant) uniform parameter
+{
+ uint M;
+ uint N;
+ uint K;
+ uint stride_a;
+ uint stride_b;
+ uint stride_d;
+
+ uint batch_stride_a;
+ uint batch_stride_b;
+ uint batch_stride_d;
+
+#ifdef MUL_MAT_ID
+ uint nei0;
+ uint nei1;
+ uint nbi1;
+ uint ne11;
+#else
+ uint k_split;
+ uint ne02;
+ uint ne12;
+ uint broadcast2;
+ uint broadcast3;
+#endif
+} p;
+
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+#if QUANT_K > 1
+#define DECODEFUNCA , dequantFuncA
+#define MAT_A_TYPE float16_t
+
+#include "dequant_funcs_cm2.comp"
+
+#else
+#define DECODEFUNCA
+#define MAT_A_TYPE A_TYPE
+#endif
+
+#define MAT_B_TYPE B_TYPE
+
+#ifdef MUL_MAT_ID
+layout (binding = 3) readonly buffer IDS {int data_ids[];};
+
+shared u16vec4 row_ids[3072];
+
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
+ B_TYPE b[];
+};
+
+uint _ne1;
+shared uint _ne1_sh;
+
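+// Decode callback for the B matrix in MUL_MAT_ID mode: rows are remapped through row_ids[]
+// (built in main below), and any row at or beyond _ne1 decodes to zero.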
+B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const uint row_i = blockCoords[0];
+
+ if (row_i >= _ne1) {
+ return B_TYPE(0.0);
+ }
+
+ const u16vec4 row_idx = row_ids[row_i];
+ B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
+
+ return ret;
+}
+
+D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)
+{
+ uint dr = ir * BM + r;
+ uint dc = ic * BN + c;
+
+ if (dr < p.M && dc < _ne1) {
+ uint row_i = dc;
+ const u16vec4 row_idx = row_ids[row_i];
+ data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
+ }
+ return elem;
+}
+
+#endif
+
+void main() {
+#if defined(DATA_A_IQ4_NL)
+ init_iq4nl_shmem();
+#endif
+
+#ifdef MUL_MAT_ID
+ const uint expert_idx = gl_GlobalInvocationID.z;
+#else
+ const uint batch_idx = gl_GlobalInvocationID.z;
+
+ const uint i13 = batch_idx / p.ne12;
+ const uint i12 = batch_idx % p.ne12;
+
+ const uint i03 = i13 / p.broadcast3;
+ const uint i02 = i12 / p.broadcast2;
+
+ const uint batch_idx_a = i03 * p.ne02 + i02;
+#endif
+
+ const uint blocks_m = (p.M + BM - 1) / BM;
+ const uint ir = gl_WorkGroupID.x % blocks_m;
+ const uint ik = gl_WorkGroupID.x / blocks_m;
+ const uint ic = gl_WorkGroupID.y;
+
+#ifdef MUL_MAT_ID
+ // Spread the search across all elements in the first subgroup
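+ // Each invocation tests one (expert slot, token) pair against expert_idx; subgroup ballots compact the matches into row_ids in order and accumulate the match count in _ne1.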
+ if (gl_SubgroupID == 0) {
+ _ne1 = 0;
+ uint num_elements = p.nei1 * p.nei0;
+
+ for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
+ bool in_range = i < num_elements;
+ uint ii0 = i % p.nei0;
+ uint ii1 = i / p.nei0;
+ uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+ uint idx = subgroupBallotExclusiveBitCount(ballot);
+ if (in_range && id == expert_idx) {
+ row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
+ }
+ _ne1 += subgroupBallotBitCount(ballot);
+ }
+ _ne1_sh = _ne1;
+ }
+
+ barrier();
+
+ _ne1 = _ne1_sh;
+
+ // Workgroup has no work
+ if (ic * BN >= _ne1) return;
+#endif
+
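+ // Without MUL_MAT_ID the K dimension may be split across workgroups (split_k); this workgroup's slice covers [start_k, end_k).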
+#ifdef MUL_MAT_ID
+ uint start_k = 0;
+ const uint end_k = p.K;
+#else
+ uint start_k = ik * p.k_split;
+ const uint end_k = min(p.K, (ik + 1) * p.k_split);
+#endif
+
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+
+#ifdef MUL_MAT_ID
+ uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
+ uint pos_b = 0;
+#else
+ uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
+ uint pos_b = batch_idx * p.batch_stride_b;
+#endif
+
+ uint stride_a = p.stride_a / QUANT_K;
+ uint stride_b = p.stride_b;
+
+ // Hint to the compiler that values are aligned (want 16B alignment).
+ // Quants are always block-aligned, no alignment needed.
+#if ALIGNED
+#if QUANT_K == 1
+ stride_a &= ~7;
+#endif
+ stride_b &= ~7;
+#endif
+
+ // Create layouts for both clamped and unclamped accesses
+ tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+
+#if QUANT_K > 1
+ tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
+ tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);
+#endif
+
+ // Use end_k rather than p.K as the dimension because that's what
+ // we need to bound check against when using split_k
+ tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
+ tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k);
+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
+ tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
+ tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k);
+
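+ // B is stored as (N, K); the transposed view loads each (BN, BK) slice as the (BK, BN) B operand.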
+ tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
+
+#if !defined(MUL_MAT_ID)
+ // Detect a fast path where all loads are entirely in bounds and no clamping is required
+ if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
+#if QUANT_K == 1
+ (stride_a % 8) == 0 &&
+#endif
+ (stride_b % 8) == 0 && (start_k % 8) == 0) {
+ // Hint to the compiler that values are aligned (want 16B alignment)
+ start_k &= ~7;
+ stride_b &= ~7;
+#if QUANT_K == 1
+ stride_a &= ~7;
+#endif
+
+ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
+ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
+
+ uint k_iters = (end_k - start_k + BK - 1) / BK;
+
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+ coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
+
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
+
+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+ }
+ } else
+#endif // !defined(MUL_MAT_ID)
+ {
+ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
+
+ tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1);
+
+ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
+
+ tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
+
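+ // Generic (clamped) path: the k loop stays rolled, and each iteration picks one of four load variants depending on whether A and/or B need clamping.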
+ [[dont_unroll]]
+ for (uint block_k = start_k; block_k < end_k; block_k += BK) {
+
+ coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft;
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft;
+
+ // Clamping is expensive, so detect different code paths for each combination
+ // of A and B needing clamping.
+ bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0;
+#ifdef MUL_MAT_ID
+ bool unclampedB = true;
+#else
+ bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0;
+#endif
+ if (unclampedA && unclampedB) {
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
+#ifdef MUL_MAT_ID
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+#else
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
+#endif
+ mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
+ mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+ } else if (unclampedA && !unclampedB) {
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+ mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
+ mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+ } else if (!unclampedA && unclampedB) {
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+#ifdef MUL_MAT_ID
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+#else
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
+#endif
+ mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
+ mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+ } else if (!unclampedA && !unclampedB) {
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+ mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
+ mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+ }
+ }
+ }
+
+ // Convert from ACC_TYPE to D_TYPE
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
+ mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
+
+#ifdef MUL_MAT_ID
+ // Call callback to store each element, remapping row through shared memory
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+#else
+ tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
+
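+ // With split_k, each slice ik writes its partial result to a separate region of the output buffer; the partial sums are combined in a later reduction pass.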
+ uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+#endif
+}
#include <fcntl.h>
#endif
+#include <vulkan/vulkan_core.h>
+
#define ASYNCIO_CONCURRENCY 64
std::mutex lock;
static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;
-void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
- std::string name = _name + (fp16 ? "" : "_fp32");
+void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
+ std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
std::string out_fname = join_paths(output_dir, name + ".spv");
std::string in_path = join_paths(input_dir, in_fname);
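+ // coopmat2 (_cm2) shaders require the Vulkan 1.3 target environment; all other shaders keep targeting Vulkan 1.2.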
+ std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
+
#ifdef _WIN32
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
#else
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", in_path, "-o", out_fname};
#endif
#ifdef GGML_VULKAN_SHADER_DEBUG_INFO
}
static std::vector<std::future<void>> compiles;
-void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
+void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
{
// wait until fewer than N compiles are in progress.
// 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
}
compile_count++;
}
- compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16));
+ compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc));
}
-void matmul_shaders(bool fp16, bool matmul_id) {
- std::string load_vec = fp16 ? "8" : "4";
- std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4";
- std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4";
+void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
+ std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
+ std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
+ std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
- std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}};
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
std::string shader_name = "matmul";
if (matmul_id) {
base_dict["FLOAT16"] = "1";
}
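+ // Accumulator type for the matmul: float16_t accumulation is only generated for the coopmat2 variants.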
+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+
+ std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
+
// Shaders with f16 B_TYPE
- string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
- string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
+ string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc);
+ string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
- string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
- string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
+ string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
+ string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc);
for (const auto& tname : type_names) {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
// For unaligned, load one at a time for f32/f16, or two at a time for quants
- std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
+ std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
// For aligned matmul loads
- std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
- string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
- string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
+ std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
+
+ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
+
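+ // Quantized A types also get variants with an f16 B matrix (f32/f16 A already have f16 B shaders above).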
+ if (tname != "f16" && tname != "f32") {
+ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
+ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
+ }
}
}
std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
+ // matmul
for (const auto& fp16 : {false, true}) {
- matmul_shaders(fp16, false);
- matmul_shaders(fp16, true);
+ for (const auto& matmul_id : {false, true}) {
+ for (const auto& coopmat2 : {false, true}) {
+ for (const auto& f16acc : {false, true}) {
+#if !defined(VK_NV_cooperative_matrix2)
+ if (coopmat2) {
+ continue;
+ }
+#endif
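+ // coopmat2 shaders are only generated for fp16, and only coopmat2 gets f16-accumulator variants.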
+ if (coopmat2 && !fp16) {
+ continue;
+ }
+ if (!coopmat2 && f16acc) {
+ continue;
+ }
+ matmul_shaders(fp16, matmul_id, coopmat2, f16acc);
+ }
+ }
+ }
}
+#if defined(VK_NV_cooperative_matrix2)
+ // flash attention
+ for (const auto& f16acc : {false, true}) {
+ std::string acctype = f16acc ? "float16_t" : "float";
+
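+ // One flash attention variant per supported K/V type (f16 plus each quant type; plain f32 is skipped), all built as coopmat2 shaders.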
+ for (const auto& tname : type_names) {
+ if (tname == "f32") {
+ continue;
+ }
+
+ if (tname == "f16") {
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc);
+ } else {
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc);
+ }
+ }
+ }
+#endif
+
for (const auto& tname : type_names) {
// mul mat vec
std::string data_a_key = "DATA_A_" + to_uppercase(tname);