GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
- //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
- //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_FLASH_ATTN_EXT:
- if (op->src[1]->type != GGML_TYPE_F16) {
- return false;
- }
- if (op->src[2]->type != GGML_TYPE_F16) {
- return false;
- }
- if (op->src[0]->ne[0] == 256) {
+ if (op->src[1]->type != op->src[2]->type) {
return false;
}
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
GGML_ASSERT(ne11 % 32 == 0);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == src2->type);
GGML_ASSERT(ggml_are_same_shape (src1, src2));
bool use_vec_kernel = false;
if (ne01 >= 4 || (ne00%128 != 0)) {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
- //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+ switch (src1->type) {
+ case GGML_TYPE_F16:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+ default:
+ {
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
+ default:
+ {
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
+ default:
+ {
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
+ default:
+ {
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
+ default:
+ {
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
+ default:
+ {
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
+ {
+ GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ GGML_LOG_ERROR("add template specialization for this type\n");
+ GGML_ABORT("add template specialization for this type");
+ }
}
} else {
use_vec_kernel = true;
switch (ne00) {
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
- //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+ case 128:
+ {
+ switch (src1->type) {
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
+ default:
+ {
+ GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ GGML_LOG_ERROR("add template specialization for this type\n");
+ GGML_ABORT("add template specialization for this type");
+ }
+ }
+ } break;
+ case 256:
+ {
+ switch (src1->type) {
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
+ default:
+ {
+ GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ GGML_LOG_ERROR("add template specialization for this type\n");
+ GGML_ABORT("add template specialization for this type");
+ }
+ }
+ } break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
+ // 16*32*(nsg)
+ // the shared memory needed for the simdgroups to load the KV cache
+ // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
+ //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+
int64_t nsgmax = 2;
while (true) {
- const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
break;
}
// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
- const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsg);
- //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+ //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else {
- // half1x4 kernel
+ // half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
+ // ne00 + 2*ncpsg*(nsg)
+ // for each query, we load it as f16 in shared memory (ne00)
+ // and store the attention scores (nqptg x ncpsg) as f32
+ //
+ // 2*ne00*(nsg)
+ // each simdgroup has a full f32 head vector in shared mem to accumulate results
+ //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+
+ int64_t nsgmax = 2;
+
+ while (true) {
+ const size_t smem = FATTN_SMEM(nsgmax);
+ if (smem > device.maxThreadgroupMemoryLength) {
+ break;
+ }
+ nsgmax *= 2;
+ }
+ nsgmax /= 2;
+
// simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+ const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
int64_t nsg = 1;
while (nsg <= nsgt) {
}
nsg /= 2;
- const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsg);
- //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+ //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
}
} break;
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}
-typedef void (flash_attn_ext_f16_t)(
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device float * dst,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant uint64_t & nb21,
- constant uint64_t & nb22,
- constant uint64_t & nb23,
- constant uint64_t & nb31,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant uint32_t & n_head_log2,
- constant float & logit_softcap,
- threadgroup half * shared,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]);
-
// ref: https://arxiv.org/pdf/2307.08691.pdf
-template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
-kernel void kernel_flash_attn_ext_f16(
+// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
+kernel void kernel_flash_attn_ext(
device const char * q,
device const char * k,
device const char * v,
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
- const short iq3 = tgpig[2];
- const short iq2 = tgpig[1];
- const short iq1 = tgpig[0]*Q;
+ const int iq3 = tgpig[2];
+ const int iq2 = tgpig[1];
+ const int iq1 = tgpig[0]*Q;
- const short D4 = D/4;
- const short D8 = D/8;
- //const short Q8 = Q/8;
- const short NW = N_SIMDWIDTH;
- const short SH = (C + Q); // shared memory per simdgroup in (half)
+ const short D4 = D/4;
+ const short D8 = D/8;
+ const short D16 = D/16;
+ const short NW = N_SIMDWIDTH;
+ const short SH = (C + Q); // shared memory per simdgroup in (half)
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
const short TF = T/2; // shared memory size per query in (float)
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+ threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
+ threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
+
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
simdgroup_half8x8 lo[D8];
threadgroup_barrier(mem_flags::mem_threadgroup);
{
- float S[Q] = { [0 ... Q-1] = 0.0h };
+ float S[Q] = { [0 ... Q-1] = 0.0f };
float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+ // thread indices inside the simdgroup
+ const short tx = tiisg%4;
+ const short ty = tiisg/4;
+
// assume K and V are same shape
const short ne22 = ne12;
const short ne23 = ne13;
- // broadcast
+ // broadcast k
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;
- const short rv2 = ne02/ne22;
- const short rv3 = ne03/ne23;
-
- // k indices
const short ik2 = iq2/rk2;
const short ik3 = iq3/rk3;
- // v indices
+ // broadcast v
+ const short rv2 = ne02/ne22;
+ const short rv3 = ne03/ne23;
+
const short iv2 = iq2/rv2;
const short iv3 = iq3/rv3;
for (short cc = 0; cc < C/8; ++cc) {
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
- device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+ // this is compile-time check, so it does not have runtime overhead
+ if (is_same<block_q, half4x4>::value) {
+ // we can read directly from global memory
+ device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_half8x8 mk;
+ simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+ simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+ }
+ } else {
+ for (short ii = 0; ii < D16; ii += 4) {
+ device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
+
+ if (D16%4 == 0) {
+ // the head is evenly divisible by 4*16 = 64, so no need for bound checks
+ half4x4 tmp;
+ dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+ skv4[4*ty + tx] = tmp;
- for (short i = 0; i < D8; ++i) {
- simdgroup_half8x8 mk;
- simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+ simdgroup_barrier(mem_flags::mem_threadgroup);
- simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+#pragma unroll
+ for (short k = 0; k < 4; ++k) {
+ simdgroup_half8x8 mk;
+
+ simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+ simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+ }
+ } else {
+ if (ii + tx < D16) {
+ half4x4 tmp;
+ dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+ skv4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short k = 0; k < 4 && ii + k < D16; ++k) {
+ simdgroup_half8x8 mk;
+
+ simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+ simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+ }
+ }
+ }
}
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
// O = O + (Q*K^T)*V
{
for (short cc = 0; cc < C/8; ++cc) {
- device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+ simdgroup_float8x8 ms;
+ simdgroup_load(ms, ss + 8*cc, TF, 0, false);
+
+ if (is_same<block_q, half4x4>::value) {
+ // we can read directly from global memory
+ device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+#pragma unroll
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_half8x8 mv;
+ simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
- for (short i = 0; i < D8; ++i) {
- simdgroup_half8x8 mk;
- simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+ simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
+ }
+ } else {
+ for (short ii = 0; ii < D16; ii += 4) {
+ device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+
+ if (D16%4 == 0) {
+ // no need for bound checks
+ half4x4 tmp;
+ dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+ skv4[4*ty + tx] = tmp;
- simdgroup_float8x8 mv;
- simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+ simdgroup_barrier(mem_flags::mem_threadgroup);
- simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+#pragma unroll
+ for (short k = 0; k < 4; ++k) {
+ simdgroup_half8x8 mv;
+
+ simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+ simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+ }
+ } else {
+ if (ii + tx < D16) {
+ half4x4 tmp;
+ dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+ skv4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short k = 0; k < 4 && ii + k < D16; ++k) {
+ simdgroup_half8x8 mv;
+
+ simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+ simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+ }
+ }
+ }
}
}
}
// reduce the warps sequentially
for (short sg = 1; sg < nsg; ++sg) {
- float S = { 0.0h };
+ float S = { 0.0f };
float M = { -FLT_MAX/2 };
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
-//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
-
-template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec_f16(
+typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
+
+// NOTE: can use half instead of float precision for some extra perf
+// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
+kernel void kernel_flash_attn_ext_vec(
device const char * q,
device const char * k,
device const char * v,
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
- const short iq3 = tgpig[2];
- const short iq2 = tgpig[1];
- const short iq1 = tgpig[0];
+ const int iq3 = tgpig[2];
+ const int iq2 = tgpig[1];
+ const int iq1 = tgpig[0];
- const short D4 = D/4;
- const short NW = N_SIMDWIDTH;
- const short SH = (C + Q); // shared memory per simdgroup in (half)
+ const short D4 = D/4;
+ const short D16 = D/16;
+ const short NW = N_SIMDWIDTH;
+ const short NW4 = NW/4;
+ const short SH = C; // shared memory per simdgroup in (half)
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
- float slope = 1.0f;
-
- // ALiBi
- if (max_bias > 0.0f) {
- const uint32_t h = iq2;
-
- const float base = h < n_head_log2 ? m0 : m1;
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
- slope = pow(base, exp);
- }
-
- //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
- threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
- threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
- threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
- threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
+ //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+ threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
+ threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
+ threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
+ threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
- half4 lo[D4/NW];
+ float4x4 lo[D16/NW4];
// load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
}
// zero out lo
- for (short i = tiisg; i < D4; i += NW) {
- lo[i/NW] = 0.0h;
+ for (short i = 0; i < D16/NW4; i += NW4) {
+ lo[i] = float4x4(0.0f);
}
// zero out shared memory SH
threadgroup_barrier(mem_flags::mem_threadgroup);
{
- float S = { 0.0h };
- float M = { -FLT_MAX/2 };
+ float S = 0.0f;
+ float M = -FLT_MAX/2;
+
+ // thread indices inside the simdgroup
+ const short tx = tiisg%8;
+ const short ty = tiisg/8;
// assume K and V are same shape
const short ne22 = ne12;
const short ne23 = ne13;
- // broadcast
+ // broadcast k
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;
+ const short ik2 = iq2/rk2;
+ const short ik3 = iq3/rk3;
+
+ // broadcast v
const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;
- // k indices
- const short ik2 = iq2 / rk2;
- const short ik3 = iq3 / rk3;
-
- // v indices
- const short iv2 = iq2 / rv2;
- const short iv3 = iq3 / rv3;
+ const short iv2 = iq2/rv2;
+ const short iv3 = iq3/rv3;
// load the queries from shared memory into local memory
- float4 mq[D4/NW];
+ float4x4 mq[D16/NW4];
- for (short ii = 0; ii < D4; ii += NW) {
- short i = ii + tiisg;
- mq[ii/NW] = (float4) sq4[i];
+ for (short ii = 0; ii < D16; ii += NW4) {
+ mq[ii/NW4] = (float4x4) sq44[ii + tx];
}
// pointer to the mask
- device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+ device const half * mp = (device const half *) (mask + iq1*nb31);
+
+ float slope = 1.0f;
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const uint32_t h = iq2;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exp);
+ }
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
// Q*K^T
{
-#pragma unroll
+ // each simdgroup processes 1 query and 4 keys
for (short cc = 0; cc < C/4; ++cc) {
- float4 mqk = { 0.0h };
+ float mqk = 0.0;
- device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+ device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
#pragma unroll
- for (short ii = 0; ii < D4; ii += NW) {
- const short i = ii + tiisg;
+ for (short ii = 0; ii < D16; ii += NW4) {
+ const short i = ii + tx;
float4x4 mk;
- mk[0] = (float4) pk4[i + 0*(nb11/8)];
- mk[1] = (float4) pk4[i + 1*(nb11/8)];
- mk[2] = (float4) pk4[i + 2*(nb11/8)];
- mk[3] = (float4) pk4[i + 3*(nb11/8)];
+ dequantize_func(pk + i/nl, i%nl, mk);
- mqk += (float4) (mq[ii/NW] * mk);
+ mqk +=
+ dot(mq[ii/NW4][0], mk[0]) +
+ dot(mq[ii/NW4][1], mk[1]) +
+ dot(mq[ii/NW4][2], mk[2]) +
+ dot(mq[ii/NW4][3], mk[3]);
}
- // reduce the results from the threads in the simdgroup
- mqk += simd_shuffle_down(mqk, 16);
- mqk += simd_shuffle_down(mqk, 8);
+ // simdgroup reduce
+ // [ 0 .. 7] -> [ 0]
+ // [ 8 .. 15] -> [ 8]
+ // [16 .. 23] -> [16]
+ // [24 .. 31] -> [24]
+ //mqk += simd_shuffle_down(mqk, 16);
+ //mqk += simd_shuffle_down(mqk, 8);
mqk += simd_shuffle_down(mqk, 4);
mqk += simd_shuffle_down(mqk, 2);
mqk += simd_shuffle_down(mqk, 1);
// mqk = mqk*scale + mask*slope
- if (tiisg == 0) {
+ if (tx == 0) {
mqk *= scale;
if (logit_softcap != 0.0f) {
mqk = logit_softcap*precise::tanh(mqk);
}
- mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
+ mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
- ss4[cc] = mqk;
+ ss[4*cc + ty] = mqk;
}
}
}
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
// online softmax
{
const short p = tiisg;
// O = diag(ms)*O
#pragma unroll
- for (short ii = 0; ii < D4; ii += NW) {
- lo[ii/NW] *= ms;
+ for (short ii = 0; ii < D16; ii += NW4) {
+ lo[ii/NW4] *= ms;
}
}
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
// O = O + (Q*K^T)*V
{
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) {
- device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+ device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+
+ const float4x4 lss(ss[4*cc + ty]);
#pragma unroll
- for (short ii = 0; ii < D4; ii += NW) {
- const short i = ii + tiisg;
+ for (short ii = 0; ii < D16; ii += NW4) {
+ const short i = ii + tx;
+
+ float4x4 mv;
+ dequantize_func(pv4 + i/nl, i%nl, mv);
- lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
- lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
- lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
- lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+ lo[ii/NW4] += mv*lss;
}
}
}
-
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
}
}
+ // simdgroup reduce
+ // [ 0, 8, 16, 24] -> [ 0]
+ // [ 1, 9, 17, 25] -> [ 1]
+ // [ 2, 10, 18, 26] -> [ 2]
+ // [ 3, 11, 19, 27] -> [ 3]
+ // [ 4, 12, 20, 28] -> [ 4]
+ // [ 5, 13, 21, 29] -> [ 5]
+ // [ 6, 14, 22, 30] -> [ 6]
+ // [ 7, 15, 23, 31] -> [ 7]
+ for (short ii = 0; ii < D16; ii += NW4) {
+ lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16);
+ lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 8);
+
+ lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16);
+ lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 8);
+
+ lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16);
+ lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 8);
+
+ lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16);
+ lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 8);
+ }
+
// store results to shared memory
- for (short ii = 0; ii < D4; ii += NW) {
- short i = ii + tiisg;
- sr4[i] = lo[ii/NW];
+ for (short i = tiisg; i < D16; i += NW4) {
+ sr44[i] = lo[i/NW4];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
- for (short ii = 0; ii < D4; ii += NW) {
- short i = ii + tiisg;
- sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+ for (short i = tiisg; i < D16; i += NW) {
+ sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- device float4 * dst4 = (device float4 *) dst;
+ device float4x4 * dst44 = (device float4x4 *) dst;
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
const float S = ss[0];
- for (short ii = 0; ii < D4; ii += NW) {
- short i = ii + tiisg;
- dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+ for (short i = tiisg; i < D16; i += NW) {
+ dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
}
}
}
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
-//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+typedef decltype(kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
template<typename T0, typename T1>
kernel void kernel_cpy(