]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : support MTLGPUFamily < Apple7, formatting, style (#3524)
authorGeorgi Gerganov <redacted>
Sun, 8 Oct 2023 07:01:53 +0000 (10:01 +0300)
committerGitHub <redacted>
Sun, 8 Oct 2023 07:01:53 +0000 (10:01 +0300)
* metal : improve decoding speed for batches of 2-16

* metal : rename kernels mul_mat_ to mul_mv_

* metal : indentations

* minor

* metal : print more GPU info + disable mul_mm for MTLGPUFamiliy < Apple7

ggml-metal.m
ggml-metal.metal

index f8fa05dd9f4a7c5e30d296ca6060275660c0fa75..57c238ddaaf6770dd058f864e1cf5a317f9de7fe 100644 (file)
@@ -81,18 +81,18 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
-    GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
-    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -262,28 +262,30 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
-        GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
-        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+        if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+            GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+        }
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
@@ -296,8 +298,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
 #if TARGET_OS_OSX
+    // print MTL GPU family:
+    GGML_METAL_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
+    GGML_METAL_LOG_INFO("%s: GPU arch:   %s\n", __func__, [[ctx->device architecture].name UTF8String]);
+
+    // determine max supported GPU family
+    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+    for (int i = MTLGPUFamilyApple9 + 10; i >= MTLGPUFamilyApple1; --i) {
+        if ([ctx->device supportsFamily:i]) {
+            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+            break;
+        }
+    }
+
+    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     if (ctx->device.maxTransferRate != 0) {
         GGML_METAL_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
@@ -339,28 +355,30 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
-    GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
-    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+    if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+        GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    }
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
@@ -986,21 +1004,46 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_MUL_MAT:
                         {
-                            // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
-
                             GGML_ASSERT(ne00 == ne10);
-                            // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
-                            uint gqa = ne12/ne02;
                             GGML_ASSERT(ne03 == ne13);
 
+                            const uint gqa = ne12/ne02;
+
+                            // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                            // to the matrix-vector kernel
+                            int ne11_mm_min = 1;
+
+#if 0
+                            // the numbers below are measured on M2 Ultra for 7B and 13B models
+                            // these numbers do not translate to other devices or model sizes
+                            // TODO: need to find a better approach
+                            if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+                                switch (src0t) {
+                                    case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
+                                    case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
+                                    case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+                                    case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
+                                    case GGML_TYPE_Q4_0:
+                                    case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+                                    case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+                                    case GGML_TYPE_Q5_0:                          // not tested yet
+                                    case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+                                    case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
+                                    case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
+                                    default:             ne11_mm_min = 1;  break;
+                                }
+                            }
+#endif
+
                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                            if (!ggml_is_transposed(src0) &&
+                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                                !ggml_is_transposed(src0) &&
                                 !ggml_is_transposed(src1) &&
                                 src1t == GGML_TYPE_F32 &&
-                                [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                                ne00%32 == 0 &&
-                                ne11 > 2) {
+                                ne00 % 32 == 0 &&
+                                ne11 > ne11_mm_min) {
+                                //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
                                 switch (src0->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
@@ -1029,17 +1072,18 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
                                 [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:13];
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
                                 int nrows = 1;
+                                //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
                                 // use custom matrix x vector kernel
                                 switch (src0t) {
                                     case GGML_TYPE_F32:
                                         {
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                             nrows = 4;
                                         } break;
                                     case GGML_TYPE_F16:
@@ -1047,12 +1091,12 @@ void ggml_metal_graph_compute(
                                             nth0 = 32;
                                             nth1 = 1;
                                             if (ne11 * ne12 < 4) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
                                             } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
                                                 nrows = ne11;
                                             } else {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
                                                 nrows = 4;
                                             }
                                         } break;
@@ -1063,7 +1107,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
                                         } break;
                                     case GGML_TYPE_Q4_1:
                                         {
@@ -1072,7 +1116,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
                                         } break;
                                     case GGML_TYPE_Q8_0:
                                         {
@@ -1081,7 +1125,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
                                         } break;
                                     case GGML_TYPE_Q2_K:
                                         {
@@ -1090,7 +1134,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
                                         } break;
                                     case GGML_TYPE_Q3_K:
                                         {
@@ -1099,7 +1143,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
                                         } break;
                                     case GGML_TYPE_Q4_K:
                                         {
@@ -1108,7 +1152,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4; //1;
                                             nth1 = 8; //32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
                                         {
@@ -1117,7 +1161,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
                                         } break;
                                     case GGML_TYPE_Q6_K:
                                         {
@@ -1126,7 +1170,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
                                         } break;
                                     default:
                                         {
@@ -1155,7 +1199,7 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
-                                    src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
+                                    src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q4_K) {
index 9bd94e82bfe1e7a260da7c698d237156e383026d..b6288db28660dc9e8902cb7a27369e9af30db695 100644 (file)
@@ -13,8 +13,8 @@ typedef struct {
 
 #define QK4_1 32
 typedef struct {
-    half d;          // delta
-    half m;          // min
+    half d;                 // delta
+    half m;                 // min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
 
@@ -423,8 +423,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
 }
 
 // putting them in the kernel cause a significant performance penalty
-#define N_DST 4 // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+#define N_DST 4        // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
@@ -435,18 +435,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
                     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
                     uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
+
     const int first_row = (r0 * nsg + sgitg) * nr;
+
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
     device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
-    float yl[16];       // src1 vector cache
-    float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = 8*(tiisg%2);
+    float yl[16]; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const int ix = (tiisg/2);
+    const int il = (tiisg%2)*8;
 
     device const float * yb = y + ix * QK4_0 + il;
 
@@ -457,6 +462,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
             sumy += yb[i] + yb[i+1];
             yl[i+0] = yb[i+ 0];
             yl[i+1] = yb[i+ 1]/256.f;
+
             sumy += yb[i+16] + yb[i+17];
             yl[i+8] = yb[i+16]/16.f;
             yl[i+9] = yb[i+17]/4096.f;
@@ -472,12 +478,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     for (int row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+            dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
         }
     }
 }
 
-kernel void kernel_mul_mat_q4_0_f32(
+kernel void kernel_mul_mv_q4_0_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -490,12 +496,12 @@ kernel void kernel_mul_mat_q4_0_f32(
         constant   int64_t & ne1[[buffer(16)]],
         constant   uint    & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mv_q4_1_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -515,7 +521,7 @@ kernel void kernel_mul_mat_q4_1_f32(
 
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mat_q8_0_f32(
+kernel void kernel_mul_mv_q8_0_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -579,7 +585,7 @@ kernel void kernel_mul_mat_q8_0_f32(
 
 #define N_F32_F32 4
 
-kernel void kernel_mul_mat_f32_f32(
+kernel void kernel_mul_mv_f32_f32(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -650,7 +656,7 @@ kernel void kernel_mul_mat_f32_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32_1row(
+kernel void kernel_mul_mv_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -669,7 +675,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
@@ -704,7 +710,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mv_f16_f32(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -776,7 +782,7 @@ kernel void kernel_mul_mat_f16_f32(
 }
 
 // Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mat_f16_f32_l4(
+kernel void kernel_mul_mv_f16_f32_l4(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1253,7 +1259,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mat_q2_K_f32(
+kernel void kernel_mul_mv_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1397,7 +1403,7 @@ kernel void kernel_mul_mat_q2_K_f32(
 }
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1549,7 +1555,7 @@ kernel void kernel_mul_mat_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1620,7 +1626,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 #endif
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1726,7 +1732,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1815,7 +1821,7 @@ kernel void kernel_mul_mat_q4_K_f32(
 }
 #endif
 
-kernel void kernel_mul_mat_q5_K_f32(
+kernel void kernel_mul_mv_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1988,7 +1994,7 @@ kernel void kernel_mul_mat_q5_K_f32(
 
 }
 
-kernel void kernel_mul_mat_q6_K_f32(
+kernel void kernel_mul_mv_q6_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2326,7 +2332,7 @@ kernel void kernel_get_rows(
 }
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
-#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
 #define BLOCK_SIZE_K 32
 #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
 #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@@ -2363,9 +2369,11 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     const uint r0 = tgpig.y;
     const uint r1 = tgpig.x;
     const uint im = tgpig.z;
+
     // if this block is of 64x32 shape or smaller
     short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
     short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
     // a thread shouldn't load data outside of the matrix
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2389,26 +2397,30 @@ kernel void kernel_mul_mm(device const  uchar * src0,
         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-        //load data and store to threadgroup memory
+        // load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
+
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
-            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
         }
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
-                = *((device float2x4 *)y);
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x  = (il < 2) ? x + (2+nl-1)/nl : x;
         y += BLOCK_SIZE_K;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        //load matrices from threadgroup memory and conduct outer products
+
+        // load matrices from threadgroup memory and conduct outer products
         threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
         #pragma unroll(4)
         for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
             #pragma unroll(4)
@@ -2423,6 +2435,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
             lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
             #pragma unroll(8)
             for (int i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2431,25 +2444,26 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     }
 
     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
-                          + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \
+                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-        if (sgitg==0) {
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg == 0) {
             for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                     *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                 }
             }