GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
- int ne11_mm_min = 4;
+ const int ne11_mm_min = 4;
+
+ // first try to use small-batch mat-mv kernels
+ // these should be efficient for BS [2, ~8]
+ if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+ (
+ (
+ (
+ src0t == GGML_TYPE_F16 || // TODO: helper function
+ src0t == GGML_TYPE_Q4_0 ||
+ src0t == GGML_TYPE_Q4_1 ||
+ src0t == GGML_TYPE_Q5_0 ||
+ src0t == GGML_TYPE_Q5_1 ||
+ src0t == GGML_TYPE_Q8_0 ||
+ src0t == GGML_TYPE_IQ4_NL ||
+ false) && (ne11 >= 2 && ne11 <= 8)
+ ) ||
+ (
+ (
+ src0t == GGML_TYPE_Q4_K ||
+ src0t == GGML_TYPE_Q5_K ||
+ src0t == GGML_TYPE_Q6_K ||
+ false) && (ne11 >= 4 && ne11 <= 8)
+ )
+ )
+ ) {
+ // TODO: determine the optimal parameters based on grid utilization
+ // I still don't know why we should not always use the maximum available threads:
+ //
+ // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+ //
+ // my current hypothesis is that the work grid is not evenly divisible for different nsg
+ // values and there can be some tail effects when nsg is high. need to confirm this
+ //
+ const int nsg = 2; // num simdgroups per threadgroup
+ const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+ const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+ const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
+ int r1ptg = 4; // num src1 rows per threadgroup
+
+ // note: not sure how optimal these are across different hardware. there might be something cleverer
+ switch (ne11) {
+ case 2:
+ r1ptg = 2; break;
+ case 3:
+ case 6:
+ r1ptg = 3; break;
+ case 4:
+ case 7:
+ case 8:
+ r1ptg = 4; break;
+ case 5:
+ r1ptg = 5; break;
+ };
-#if 0
- // the numbers below are measured on M2 Ultra for 7B and 13B models
- // these numbers do not translate to other devices or model sizes
- // TODO: need to find a better approach
- if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
- switch (src0t) {
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
- case GGML_TYPE_Q5_0: // not tested yet
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
- default: ne11_mm_min = 1; break;
- }
- }
-#endif
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
+ case GGML_TYPE_Q4_0:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
+ case GGML_TYPE_Q4_1:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
+ case GGML_TYPE_Q5_0:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
+ case GGML_TYPE_Q5_1:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
+ case GGML_TYPE_Q8_0:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
+ case GGML_TYPE_Q4_K:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
+ case GGML_TYPE_Q5_K:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
+ case GGML_TYPE_Q6_K:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
+ case GGML_TYPE_IQ4_NL:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
+ default: GGML_ABORT("not implemented");
+ }
+
+ ggml_metal_kargs_mul_mv_ext args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ /*.nsg =*/ nsg,
+ /*.nxpsg =*/ nxpsg,
+ /*.r1ptg =*/ r1ptg,
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+ //printf("ne01 = %lld r0ptg = %d\n", ne01, r0ptg);
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if ([device supportsFamily:MTLGPUFamilyApple7] &&
reg = (type4x4)(*src);
}
+// reads the il-th half4 element counted from src and converts it to type4
+template <typename type4>
+void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
+    reg = (type4) src[il];
+}
+
#if defined(GGML_METAL_USE_BF16)
template <typename type4x4>
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
#endif
template <typename type4x4>
-void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
reg = (type4x4) reg_f;
}
+// dequantizes one 4-value chunk of a q4_0 block into reg
+//   il/4 selects the nibble: 0 -> low nibble of each byte, non-zero -> high nibble
+//   il%4 selects which pair of 16-bit words of the packed quants is read
+template <typename type4>
+void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
+    // the packed quants are read starting one 16-bit word into the block
+    device const uint16_t * quants = ((device const uint16_t *)xb + 1);
+
+    const bool upper = (il/4) != 0;
+    const int  first = 2*(il%4);
+
+    // masks for the selected nibble in the low and in the high byte of a 16-bit word
+    const ushort lo_mask = upper ? 0x00F0 : 0x000F;
+    const ushort hi_mask = lo_mask << 8;
+
+    // the masked values are not shifted down, the scales are divided instead:
+    // /16 for the high nibble, /256 for the high byte
+    const float scale_lo = upper ? (xb->d / 16.h) : xb->d;
+    const float scale_hi = scale_lo / 256.f;
+    const float offset   = -8.h * xb->d;
+
+    for (int k = 0; k < 2; ++k) {
+        const uint16_t packed = quants[first + k];
+
+        reg[2*k + 0] = scale_lo * (packed & lo_mask) + offset;
+        reg[2*k + 1] = scale_hi * (packed & hi_mask) + offset;
+    }
+}
+
template <typename type4x4>
-void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
reg = (type4x4) reg_f;
}
+// dequantizes one 4-value chunk of a q4_1 block into reg
+//   il/4 selects the nibble: 0 -> low nibble of each byte, non-zero -> high nibble
+//   il%4 selects which pair of 16-bit words of the packed quants is read
+template <typename type4>
+void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
+    // the packed quants are read starting two 16-bit words into the block
+    device const uint16_t * quants = ((device const uint16_t *)xb + 2);
+
+    const bool upper = (il/4) != 0;
+    const int  first = 2*(il%4);
+
+    // masks for the selected nibble in the low and in the high byte of a 16-bit word
+    const ushort lo_mask = upper ? 0x00F0 : 0x000F;
+    const ushort hi_mask = lo_mask << 8;
+
+    // the masked values are not shifted down, the scales are divided instead:
+    // /16 for the high nibble, /256 for the high byte
+    const float scale_lo = upper ? (xb->d / 16.h) : xb->d;
+    const float scale_hi = scale_lo / 256.f;
+    const float offset   = xb->m;
+
+    for (int k = 0; k < 2; ++k) {
+        const uint16_t packed = quants[first + k];
+
+        reg[2*k + 0] = scale_lo * (packed & lo_mask) + offset;
+        reg[2*k + 1] = scale_hi * (packed & hi_mask) + offset;
+    }
+}
+
template <typename type4x4>
-void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
reg = (type4x4) reg_f;
}
+// dequantizes one 4-value chunk of a q5_0 block into reg
+//   il/4 selects the nibble: 0 -> low nibble of each byte, non-zero -> high nibble
+//   il%4 selects which pair of 16-bit words of the packed quants is read
+// every value is a 4-bit nibble combined with one extra bit taken from xb->qh
+template <typename type4>
+void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
+    // the packed nibbles are read starting three 16-bit words into the block
+    device const uint16_t * quants = ((device const uint16_t *)xb + 3);
+
+    const float scale  = xb->d;
+    const float offset = -16.h * xb->d;
+
+    const uint32_t high_bits = *((device const uint32_t *)xb->qh);
+
+    const bool upper = (il/4) != 0;
+
+    // nibble selection inside each byte
+    const ushort nib_mask  = upper ? 0x00F0 : 0x000F;
+    const int    nib_shift = upper ? 4 : 0;
+
+    // shifts that place the matching bit of high_bits at bit position 4 (0x10):
+    // low nibbles use bits 0..15, high nibbles use bits 16..31
+    const int qh_down = upper ? 12 : 0;
+    const int qh_up   = upper ? 0 : 4;
+
+    for (int k = 0; k < 2; k++) {
+        const int iq = 2*(il%4) + k;
+
+        // fifth bit for the low-byte and the high-byte value of this word
+        const uint8_t h0 = ((high_bits >> (qh_down + 2*iq  )) << qh_up) & 0x10;
+        const uint8_t h1 = ((high_bits >> (qh_down + 2*iq+1)) << qh_up) & 0x10;
+
+        // 4 bits from the packed word OR-ed with the fifth bit
+        const int32_t v0 = ((((quants[iq]     ) & nib_mask) >> nib_shift) | h0);
+        const int32_t v1 = ((((quants[iq] >> 8) & nib_mask) >> nib_shift) | h1);
+
+        reg[2*k + 0] = scale * v0 + offset;
+        reg[2*k + 1] = scale * v1 + offset;
+    }
+}
+
template <typename type4x4>
-void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
reg = (type4x4) reg_f;
}
+// dequantizes one 4-value chunk of a q5_1 block into reg
+//   il/4 selects the nibble: 0 -> low nibble of each byte, non-zero -> high nibble
+//   il%4 selects which pair of 16-bit words of the packed quants is read
+// every value is a 4-bit nibble combined with one extra bit taken from xb->qh
+template <typename type4>
+void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
+    // the packed nibbles are read starting four 16-bit words into the block
+    device const uint16_t * quants = ((device const uint16_t *)xb + 4);
+
+    const float scale  = xb->d;
+    const float offset = xb->m;
+
+    const uint32_t high_bits = *((device const uint32_t *)xb->qh);
+
+    const bool upper = (il/4) != 0;
+
+    // nibble selection inside each byte
+    const ushort nib_mask  = upper ? 0x00F0 : 0x000F;
+    const int    nib_shift = upper ? 4 : 0;
+
+    // shifts that place the matching bit of high_bits at bit position 4 (0x10):
+    // low nibbles use bits 0..15, high nibbles use bits 16..31
+    const int qh_down = upper ? 12 : 0;
+    const int qh_up   = upper ? 0 : 4;
+
+    for (int k = 0; k < 2; k++) {
+        const int iq = 2*(il%4) + k;
+
+        // fifth bit for the low-byte and the high-byte value of this word
+        const uint8_t h0 = ((high_bits >> (qh_down + 2*iq  )) << qh_up) & 0x10;
+        const uint8_t h1 = ((high_bits >> (qh_down + 2*iq+1)) << qh_up) & 0x10;
+
+        // 4 bits from the packed word OR-ed with the fifth bit
+        const int32_t v0 = ((((quants[iq]     ) & nib_mask) >> nib_shift) | h0);
+        const int32_t v1 = ((((quants[iq] >> 8) & nib_mask) >> nib_shift) | h1);
+
+        reg[2*k + 0] = scale * v0 + offset;
+        reg[2*k + 1] = scale * v1 + offset;
+    }
+}
+
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
- const half d = xb->d;
+ const float d = xb->d;
float4x4 reg_f;
reg = (type4x4) reg_f;
}
+// dequantizes one 4-value chunk of a q8_0 block into reg
+// the chunk starts at quant index 4*(il%4) + 16*(il/4); each int8 quant is multiplied by xb->d
+template <typename type4>
+void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
+    const float scale = xb->d;
+
+    // first quant of the requested chunk
+    device const int8_t * chunk = ((device const int8_t *)xb->qs) + (4*(il%4) + 16*(il/4));
+
+    reg[0] = (chunk[0] * scale);
+    reg[1] = (chunk[1] * scale);
+    reg[2] = (chunk[2] * scale);
+    reg[3] = (chunk[3] * scale);
+}
+
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d;
}
template <typename type4x4>
-void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
device const uchar * q = xb->qs;
short is = (il/4) * 2;
const float dl = d * sc[0];
const float ml = min * sc[1];
- const ushort mask = il<2 ? 0x0F : 0xF0;
+ const ushort mask = il < 2 ? 0x0F : 0xF0;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
}
+// dequantizes one 4-value chunk of an iq4_nl block into reg
+//   il%4 selects the pair of 16-bit words read from xb->qs
+//   il/4 selects the nibble of every byte (shift by 0 or 4 bits before masking)
+// each 4-bit index is looked up in kvalues_iq4nl_f and multiplied by xb->d
+template <typename type4>
+void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
+    device const uint16_t * words = (device const uint16_t *)xb->qs;
+
+    const float scale = xb->d;
+
+    const int first = 2*(il%4);
+
+    // two 16-bit words joined into 32 bits, reduced to one 4-bit index per byte
+    uint32_t packed;
+    thread const uint8_t * idx = (thread const uint8_t *)&packed;
+    packed = ((words[first] | (words[first+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
+
+    for (short k = 0; k < 4; ++k) {
+        reg[k] = scale * kvalues_iq4nl_f[idx[k]];
+    }
+}
+
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
+// mat-vec kernel processing in chunks of float4
+// chpb - chunks per quantization block
+//
+// template parameters:
+//   nxpsg  - number of threads of a simdgroup that cooperate on one src0 row
+//   r1ptg  - number of src1 rows handled per threadgroup (one accumulator per row)
+//   q_t    - element type that src0 is read as and that deq_t4 consumes
+//   chpb   - number of float4 chunks contained in one q_t element
+//   deq_t4 - writes chunk number `il` of a q_t element into a float4
+//
+// every thread accumulates partial dot products of one src0 row with r1ptg src1 rows.
+// the partial sums are then reduced over the nxpsg threads that share the row and the
+// thread with tx == 0 stores the results to dst as float
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
+void kernel_mul_mv_ext_q4_f32_impl(
+ constant ggml_metal_kargs_mul_mv_ext & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const short chpt = 4; // chunks per thread
+
+ //const short nxpsg = (32);
+ // number of src0 rows covered by one simdgroup (32 threads split into groups of nxpsg)
+ const short nypsg = (32/nxpsg);
+
+ // tx - position of the thread inside its row group, ty - which row of the simdgroup
+ const short tx = tiisg%nxpsg;
+ const short ty = tiisg/nxpsg;
+
+ // i01 - src0 row of this thread
+ // i11 - first src1 row of this threadgroup
+ // i1m - flattened index over dims 2 and 3, split into i12/i13 below
+ const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+ const int i11 = tgpig.y*r1ptg;
+ const int i1m = tgpig.z;
+
+ const int i12 = i1m%args.ne12;
+ const int i13 = i1m/args.ne12;
+
+ // byte offsets of the rows; the src0 indices along dims 2 and 3 are the src1 indices divided by r2 and r3
+ const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ // rows past ne01 read from the start of src0 instead; their sums are never stored (see the i01 check at the end)
+ device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+ device const float4 * y4[r1ptg];
+
+ // src1 rows past ne11 read from the start of src1 instead; the store loop at the end skips them
+ for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+ y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
+ }
+
+ float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+ short cch = tx%chpb; // current chunk index
+
+ // ich counts float4 chunks along the row; one iteration consumes chpt chunks per thread, spaced nxpsg apart
+ // NOTE(review): the chpt chunks are read without re-checking the bound, so this appears to assume
+ //               that ne00 is a multiple of 4*chpt*nxpsg - confirm against the host-side dispatch condition
+ for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
+ float4 lx[chpt];
+
+#pragma unroll(chpt)
+ for (short ch = 0; ch < chpt; ++ch) {
+ deq_t4(xq, cch, lx[ch]);
+
+ // step nxpsg chunks ahead, moving xq on whenever the chunk index leaves the current q_t element
+ cch += nxpsg;
+ if (cch >= chpb) {
+ xq += cch/chpb;
+ cch %= chpb;
+ }
+ }
+
+#pragma unroll(chpt)
+ for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
+
+ }
+ }
+
+ // advance the src1 pointers by the chunks consumed in this iteration
+#pragma unroll(r1ptg)
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ y4[ir1] += chpt*nxpsg;
+ }
+ }
+
+ // reduce only the threads in each row
+ // the shuffle distances halve from nxpsg/2 down to 1; only the tx == 0 thread is read afterwards
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ if (nxpsg >= 32) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+ }
+ if (nxpsg >= 16) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
+ }
+ if (nxpsg >= 8) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
+ }
+ if (nxpsg >= 4) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
+ }
+ if (nxpsg >= 2) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
+ }
+
+ //sumf[ir1] = simd_sum(sumf[ir1]);
+ }
+
+ // one thread per row stores the results; dst is indexed as float with ne0 values per row
+ // and ne0*ne1 values per i1m slice
+ if (tx == 0) {
+ for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+ device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+ if (i01 < args.ne01) {
+ dst_f32[i01] = sumf[ir1];
+ }
+ }
+ }
+}
+
+// mat-vec kernel processing in chunks of float4x4
+//
+// template parameters:
+//   nxpsg    - number of threads of a simdgroup that cooperate on one src0 row
+//   r1ptg    - number of src1 rows handled per threadgroup (one accumulator per row)
+//   q_t      - element type that src0 is read as and that deq_t4x4 consumes
+//   chpb     - number of float4x4 chunks contained in one q_t element
+//   deq_t4x4 - writes chunk number `il` of a q_t element into a float4x4
+//
+// every thread accumulates partial dot products of one src0 row with r1ptg src1 rows.
+// the partial sums are then reduced over the nxpsg threads that share the row and the
+// thread with tx == 0 stores the results to dst as float
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
+void kernel_mul_mv_ext_q4x4_f32_impl(
+ constant ggml_metal_kargs_mul_mv_ext & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ // chunks per thread handled in one loop iteration
+ const short chpt = 1;
+
+ //const short nxpsg = (32);
+ // number of src0 rows covered by one simdgroup (32 threads split into groups of nxpsg)
+ const short nypsg = (32/nxpsg);
+
+ // tx - position of the thread inside its row group, ty - which row of the simdgroup
+ const short tx = tiisg%nxpsg;
+ const short ty = tiisg/nxpsg;
+
+ // i01 - src0 row of this thread
+ // i11 - first src1 row of this threadgroup
+ // i1m - flattened index over dims 2 and 3, split into i12/i13 below
+ const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+ const int i11 = tgpig.y*r1ptg;
+ const int i1m = tgpig.z;
+
+ const int i12 = i1m%args.ne12;
+ const int i13 = i1m/args.ne12;
+
+ // byte offsets of the rows; the src0 indices along dims 2 and 3 are the src1 indices divided by r2 and r3
+ const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ // rows past ne01 read from the start of src0 instead; their sums are never stored (see the i01 check at the end)
+ device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+ device const float4x4 * y4x4[r1ptg];
+
+ // src1 rows past ne11 read from the start of src1 instead; the store loop at the end skips them
+ for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+ y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
+ }
+
+ float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+ // current chunk index inside the q_t element
+ short cch = tx%chpb;
+
+ // ich counts float4x4 chunks (16 values each) along the row
+ for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
+ float4x4 lx[chpt];
+
+#pragma unroll(chpt)
+ for (short ch = 0; ch < chpt; ++ch) {
+ deq_t4x4(xq, cch, lx[ch]);
+
+ // step nxpsg chunks ahead, moving xq on whenever the chunk index leaves the current q_t element
+ cch += nxpsg;
+ if (cch >= chpb) {
+ xq += cch/chpb;
+ cch %= chpb;
+ }
+ }
+
+#pragma unroll(chpt)
+ for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ // the 16-value dot product is the sum of four float4 dot products
+ sumf[ir1] +=
+ dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
+ dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
+ dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
+ dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);
+
+ }
+ }
+
+ // advance the src1 pointers by the chunks consumed in this iteration
+#pragma unroll(r1ptg)
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ y4x4[ir1] += chpt*nxpsg;
+ }
+ }
+
+ // reduce the partial sums of the threads that share a row
+ // the shuffle distances halve from nxpsg/2 down to 1; only the tx == 0 thread is read afterwards
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ if (nxpsg >= 32) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+ }
+ if (nxpsg >= 16) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
+ }
+ if (nxpsg >= 8) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
+ }
+ if (nxpsg >= 4) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
+ }
+ if (nxpsg >= 2) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
+ }
+
+ //sumf[ir1] = simd_sum(sumf[ir1]);
+ }
+
+ // one thread per row stores the results; dst is indexed as float with ne0 values per row
+ // and ne0*ne1 values per i1m slice
+ if (tx == 0) {
+ for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+ device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+ if (i01 < args.ne01) {
+ dst_f32[i01] = sumf[ir1];
+ }
+ }
+ }
+}
+
+// kernel entry point that maps the runtime value args.nxpsg onto the compile-time
+// nxpsg template argument of kernel_mul_mv_ext_q4_f32_impl
+// epb - elements per quantization block; the impl receives epb/4 as its chunks-per-block argument
+template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
+kernel void kernel_mul_mv_ext_q4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    // only 4, 8, 16 and 32 are handled; any other value of args.nxpsg does no work
+    if (args.nxpsg == 4) {
+        kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
+    } else if (args.nxpsg == 8) {
+        kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
+    } else if (args.nxpsg == 16) {
+        kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
+    } else if (args.nxpsg == 32) {
+        kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
+    }
+}
+
+// kernel entry point that maps the runtime value args.nxpsg onto the compile-time
+// nxpsg template argument of kernel_mul_mv_ext_q4x4_f32_impl
+// the impl receives epb/16 as its chunks-per-block argument
+template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
+kernel void kernel_mul_mv_ext_q4x4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    // only 4, 8, 16 and 32 are handled; any other value of args.nxpsg does no work
+    if (args.nxpsg == 4) {
+        kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
+    } else if (args.nxpsg == 8) {
+        kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
+    } else if (args.nxpsg == 16) {
+        kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
+    } else if (args.nxpsg == 32) {
+        kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
+    }
+}
+
+typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
+typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
+
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>; // explicit instantiations of the q4 dispatcher, one per (source type, r1) pair; the first template argument always equals the N of the _r1_N host name. f16 uses half4 as its type with 4 as the third template argument
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; // from here on every block_* type is instantiated with 32 as the third template argument and its own dequantize_*_t4 callback
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>; // explicit instantiations of the q4x4 dispatcher for q4_K, q5_K and q6_K; the first template argument (r1ptg) equals the N of the _r1_N host name, and epb is 256 throughout, so the dispatcher forwards epb/16 == 16 to the implementation
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;
+
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
+
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
+
#define N_MV_T_T 4
template<typename T0, typename T04, typename T1, typename T14, typename args_t>