]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : matrix-matrix multiplication kernel (#2615)
authorShouzheng Liu <redacted>
Wed, 16 Aug 2023 20:07:04 +0000 (16:07 -0400)
committerGitHub <redacted>
Wed, 16 Aug 2023 20:07:04 +0000 (23:07 +0300)
* metal: matrix-matrix multiplication kernel

This commit removes MPS and uses custom matrix-matrix multiplication
kernels for all quantization types. This commit also adds grouped-query
attention to support llama2 70B.

* metal: fix performance degradation from gqa

Integers are slow on the GPU, and 64-bit divides are extremely slow.
In the context of GQA, we introduce a 64-bit divide that cannot be
optimized out by the compiler, which results in a decrease of ~8% in
inference performance. This commit fixes that issue by calculating a
part of the offset with a 32-bit divide. Naturally, this limits the
size of a single matrix to ~4GB. However, this limitation should
suffice for the near future.

* metal: fix bugs for GQA and perplexity test.

I mixed up ne02 and nb02 in previous commit.

CMakeLists.txt
Makefile
flake.nix
ggml-metal.m
ggml-metal.metal
llama.cpp

index dff4942cdb9b2703cc5d7e145f538806c6244dc1..01b40c2e881fb53b16c4d32010fc93d510989855 100644 (file)
@@ -296,7 +296,6 @@ if (LLAMA_METAL)
     find_library(FOUNDATION_LIBRARY         Foundation              REQUIRED)
     find_library(METAL_FRAMEWORK            Metal                   REQUIRED)
     find_library(METALKIT_FRAMEWORK         MetalKit                REQUIRED)
-    find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
 
     set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
 
@@ -313,7 +312,6 @@ if (LLAMA_METAL)
         ${FOUNDATION_LIBRARY}
         ${METAL_FRAMEWORK}
         ${METALKIT_FRAMEWORK}
-        ${METALPERFORMANCE_FRAMEWORK}
         )
 endif()
 
index 070ae124282065e356b81f9145d91c904b21049c..5b801d16f5ecde9a222373451052bc66e2c88873 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -283,7 +283,7 @@ endif # LLAMA_CLBLAST
 ifdef LLAMA_METAL
        CFLAGS   += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
        CXXFLAGS += -DGGML_USE_METAL
-       LDFLAGS  += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
+       LDFLAGS  += -framework Foundation -framework Metal -framework MetalKit
        OBJS     += ggml-metal.o
 endif # LLAMA_METAL
 
index 4178e97ffc5d0273674f7b104a078d18754647eb..616b902529d46c5a8c104083af8747c7ff8c1c47 100644 (file)
--- a/flake.nix
+++ b/flake.nix
@@ -14,8 +14,6 @@
             with pkgs.darwin.apple_sdk_11_0.frameworks; [
               Accelerate
               MetalKit
-              MetalPerformanceShaders
-              MetalPerformanceShadersGraph
             ]
           else if isAarch32 && isDarwin then
             with pkgs.darwin.apple_sdk.frameworks; [
index fbac21e3a9c7f0324f7b6b932f9360bef97955cf..e13cb4b3cd9ba136e37a421e4b98ae23b211186a 100644 (file)
@@ -5,7 +5,6 @@
 #import <Foundation/Foundation.h>
 
 #import <Metal/Metal.h>
-#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 
 #undef MIN
 #undef MAX
@@ -79,6 +78,14 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
     GGML_METAL_DECL_KERNEL(rope);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -110,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
-    // determine if we can use MPS
-    if (MPSSupportsMTLDevice(ctx->device)) {
-        fprintf(stderr, "%s: using MPS\n", __func__);
-    } else {
-        fprintf(stderr, "%s: not using MPS\n", __func__);
-        GGML_ASSERT(false && "MPS not supported");
-    }
 
 #if 0
     // compile from source string and show compile log
@@ -196,6 +196,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
         GGML_METAL_ADD_KERNEL(rope);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -506,7 +514,7 @@ void ggml_metal_graph_compute(
 
             id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
 
-            id<MTLComputeCommandEncoder> encoder = nil;
+            id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
 
             const int node_start =                                  (cb_idx + 0) * n_nodes_per_cb;
             const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
@@ -515,10 +523,6 @@ void ggml_metal_graph_compute(
                 const int i = has_concur ? ctx->concur_list[ind] : ind;
 
                 if (i == -1) {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                        continue;
-                    }
                     [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
                     continue;
                 }
@@ -592,10 +596,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_ADD:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -613,10 +613,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_MUL:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -634,10 +630,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_SCALE:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             const float scale = *(const float *) src1->data;
 
                             [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -653,10 +645,6 @@ void ggml_metal_graph_compute(
                         switch (ggml_get_unary_op(gf->nodes[i])) {
                             case GGML_UNARY_OP_SILU:
                                 {
-                                    if (encoder == nil) {
-                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                                    }
-
                                     [encoder setComputePipelineState:ctx->pipeline_silu];
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -667,10 +655,6 @@ void ggml_metal_graph_compute(
                                 } break;
                             case GGML_UNARY_OP_RELU:
                                 {
-                                    if (encoder == nil) {
-                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                                    }
-
                                     [encoder setComputePipelineState:ctx->pipeline_relu];
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -681,10 +665,6 @@ void ggml_metal_graph_compute(
                                 } break;
                             case GGML_UNARY_OP_GELU:
                                 {
-                                    if (encoder == nil) {
-                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                                    }
-
                                     [encoder setComputePipelineState:ctx->pipeline_gelu];
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -701,10 +681,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_SOFT_MAX:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             const int nth = 32;
 
                             [encoder setComputePipelineState:ctx->pipeline_soft_max];
@@ -719,10 +695,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_DIAG_MASK_INF:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             const int n_past = ((int32_t *)(dst->op_params))[0];
 
                             [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
@@ -740,53 +712,43 @@ void ggml_metal_graph_compute(
 
                             GGML_ASSERT(ne00 == ne10);
                             // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+                            uint gqa = ne12/ne02;
                             GGML_ASSERT(ne03 == ne13);
 
+                            // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                            // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                             if (ggml_is_contiguous(src0) &&
                                 ggml_is_contiguous(src1) &&
-                                (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
-
-                                if (encoder != nil) {
-                                    [encoder endEncoding];
-                                    encoder = nil;
-                                }
-
-                                MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-                                MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-
-                                // for F32 x F32 we use MPS
-                                MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
-                                    matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
-
-                                MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
-                                    matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
-
-                                MPSMatrixDescriptor * desc  = [MPSMatrixDescriptor
-                                    matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
-
-                                MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
-                                    initWithDevice:ctx->device transposeLeft:false transposeRight:true
-                                        resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
-
-                                // we need to do ne12 multiplications
-                                // TODO: is there a way to do this in parallel - currently very slow ..
-                                // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
-                                for (int64_t i02 = 0; i02 < ne12; ++i02) {
-                                    size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
-                                    size_t offs_src1_cur = offs_src1 + i02*nb12;
-                                    size_t offs_dst_cur  = offs_dst  + i02*nb2;
-
-                                    MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
-                                    MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
-                                    MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];
-
-                                    [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
-                                }
-                            } else {
-                                if (encoder == nil) {
-                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                src1t == GGML_TYPE_F32 &&
+                                [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                                ne00%32 == 0 &&
+                                ne11 > 1) {
+                                    switch (src0->type) {
+                                        case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+                                        case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+                                        case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                                        case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+                                        case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+                                        case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+                                        case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+                                        case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+                                        default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
+                                    }
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+                                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+                                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+                                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
+                                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
+                                    [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+                                    [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                                 }
-
+                            else {
                                 int nth0 = 32;
                                 int nth1 = 1;
 
@@ -885,23 +847,24 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
                                 [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
+                                [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                                 }
                                 else if (src0t == GGML_TYPE_Q5_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q6_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -910,10 +873,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_GET_ROWS:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             switch (src0->type) {
                                 case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
@@ -939,10 +898,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_RMS_NORM:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             float eps;
                             memcpy(&eps, dst->op_params, sizeof(float));
 
@@ -962,10 +917,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_NORM:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             const float eps = 1e-5f;
 
                             const int nth = 256;
@@ -984,10 +935,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_ALIBI:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             GGML_ASSERT((src0t == GGML_TYPE_F32));
 
                             const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
@@ -1027,10 +974,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_ROPE:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             const int n_past = ((int32_t *) dst->op_params)[0];
                             const int n_dims = ((int32_t *) dst->op_params)[1];
                             const int mode   = ((int32_t *) dst->op_params)[2];
@@ -1071,10 +1014,6 @@ void ggml_metal_graph_compute(
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             const int nth = 32;
 
                             switch (src0t) {
index 8d26b5ec2dfa4f649b7a6d478f73238a0c700da6..3f3125236f197f36f873b47daaf73f7be9834195 100644 (file)
@@ -18,47 +18,6 @@ typedef struct {
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
 
-static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
-    const int qk = QK4_0;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        const half d = x[i].d;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int x0 = (x[i].qs[j] & 0x0F) - 8;
-            const int x1 = (x[i].qs[j] >>   4) - 8;
-
-            y[i*qk + j + 0   ] = x0*d;
-            y[i*qk + j + qk/2] = x1*d;
-        }
-    }
-}
-
-static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
-    const int qk = QK4_1;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        const half d = x[i].d;
-        const half m = x[i].m;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int x0 = (x[i].qs[j] & 0x0F);
-            const int x1 = (x[i].qs[j] >>   4);
-
-            y[i*qk + j + 0   ] = x0*d + m;
-            y[i*qk + j + qk/2] = x1*d + m;
-        }
-    }
-}
-
 kernel void kernel_add(
         device const float * src0,
         device const float * src1,
@@ -219,54 +178,6 @@ kernel void kernel_diag_mask_inf(
     }
 }
 
-kernel void kernel_get_rows_f16(
-        device const  void * src0,
-        device const   int * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb1,
-        uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    for (int j = 0; j < ne00; j++) {
-        dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
-    }
-}
-
-kernel void kernel_get_rows_q4_0(
-        device const  void * src0,
-        device const   int * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb1,
-        uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q4_0(
-            (device const block_q4_0 *) ((device char *) src0 + r*nb01),
-                       (device float *) ((device char *)  dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q4_1(
-        device const  void * src0,
-        device const   int * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb1,
-        uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q4_1(
-            (device const block_q4_1 *) ((device char *) src0 + r*nb01),
-                       (device float *) ((device char *)  dst + i*nb1), ne00);
-}
-
 kernel void kernel_norm(
         device const  void * src0,
         device       float * dst,
@@ -432,14 +343,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
 //      N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
-                    int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
-                    uint2 tgpig, uint tiisg, uint sgitg) {
+                    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
+                    uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
+    const int im = tgpig.z;
     const int first_row = (r0 * nsg + sgitg) * nr;
-    device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
-    device const float      * y = (device const float      *) src1 + r1*ne10;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
     float yl[16];       // src1 vector cache
     float sumf[nr]={0.f};
 
@@ -470,7 +383,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     for (int row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + first_row + row] = tot;
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
         }
     }
 }
@@ -480,13 +393,17 @@ kernel void kernel_mul_mat_q4_0_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
         constant   int64_t & ne01[[buffer(4)]],
-        uint2 tgpig[[threadgroup_position_in_grid]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_q4_1_f32(
@@ -494,13 +411,17 @@ kernel void kernel_mul_mat_q4_1_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
         constant   int64_t & ne01[[buffer(4)]],
-        uint2 tgpig[[threadgroup_position_in_grid]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_f16_f32(
@@ -869,354 +790,6 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
     return r;
 }
 
-//========================================== dequantization =============================
-
-static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
-
-    for (int i = 0; i < nb; i++) {
-
-        const float d = x[i].d;
-        const float min = x[i].dmin;
-
-        device const uint8_t * q = x[i].qs;
-
-#if QK_K == 256
-        int is = 0;
-        float dl, ml;
-        for (int n = 0; n < QK_K; n += 128) {
-            int shift = 0;
-            for (int j = 0; j < 4; ++j) {
-
-                uint8_t sc = x[i].scales[is++];
-                dl = d * (sc & 0xF); ml = min * (sc >> 4);
-                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
-
-                sc = x[i].scales[is++];
-                dl = d * (sc & 0xF); ml = min * (sc >> 4);
-                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
-
-                shift += 2;
-            }
-            q += 32;
-        }
-#else
-        float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
-        float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
-        float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
-        float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
-        for (int l = 0; l < 16; ++l) {
-            y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
-            y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
-            y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
-            y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
-        }
-        y += QK_K;
-#endif
-
-    }
-}
-
-static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
-
-#if QK_K == 256
-
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
-    uint16_t aux[8];
-    thread const int8_t * scales = (thread const int8_t*)aux;
-
-    for (int i = 0; i < nb; i++) {
-
-        const float d_all = (float)(x[i].d);
-
-        device const uint8_t * q = x[i].qs;
-        device const uint8_t * h = x[i].hmask;
-        uint8_t m = 1;
-
-        device const uint16_t * a = (device const uint16_t *)x[i].scales;
-        aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
-        aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
-        aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
-        aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
-        aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
-        aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
-        aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
-        aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
-
-        int is = 0;
-        float dl;
-        for (int n = 0; n < QK_K; n += 128) {
-            int shift = 0;
-            for (int j = 0; j < 4; ++j) {
-
-                dl = d_all * (scales[is++] - 32);
-                for (int l = 0; l < 16; ++l) {
-                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
-                }
-
-                dl = d_all * (scales[is++] - 32);
-                for (int l = 0; l < 16; ++l) {
-                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
-                }
-
-                shift += 2;
-                m <<= 1;
-            }
-            q += 32;
-        }
-    }
-#else
-    for (int i = 0; i < nb; i++) {
-
-        const float d_all = (float)(x[i].d);
-
-        device const uint8_t * q = x[i].qs;
-        device const uint8_t * hm = x[i].hmask;
-
-        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
-        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
-        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
-        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
-
-        for (int l = 0; l < 8; ++l) {
-            uint8_t h = hm[l];
-            y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
-            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
-            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
-            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
-            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
-            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
-            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
-            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
-        }
-        y += QK_K;
-    }
-#endif
-
-}
-
-static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
-
-    for (int i = 0; i < nb; i++) {
-
-        device const uint8_t * q = x[i].qs;
-
-#if QK_K == 256
-        const float d = x[i].d;
-        const float min = x[i].dmin;
-
-        device const uint8_t * scales = x[i].scales;
-
-        int is = 0;
-        for (int j = 0; j < QK_K; j += 64) {
-            const uchar4 sc = get_scale_min_k4(is, scales);
-            const float d1 = d * sc[0]; const float m1 = min * sc[1];
-            const float d2 = d * sc[2]; const float m2 = min * sc[3];
-            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
-            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
-            q += 32; is += 2;
-        }
-#else
-        device const uint8_t * s = x[i].scales;
-        device const half2 * dh = (device const half2 *)x[i].d;
-        const float2 d = (float2)dh[0];
-        const float d1 = d[0] * (s[0] & 0xF);
-        const float d2 = d[0] * (s[1] & 0xF);
-        const float m1 = d[1] * (s[0] >>  4);
-        const float m2 = d[1] * (s[1] >>  4);
-        for (int l = 0; l < 32; ++l) {
-            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
-            y[l+32] = d2 * (q[l] >>  4) - m2;
-        }
-        y += QK_K;
-#endif
-
-    }
-}
-
-static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
-
-#if QK_K == 256
-   for (int i = 0; i < nb; i++) {
-
-        const float d = (float)(x[i].d);
-        const float min = (float)(x[i].dmin);
-
-        device const uint8_t * ql = x[i].qs;
-        device const uint8_t * qh = x[i].qh;
-
-        int is = 0;
-        uint8_t u1 = 1, u2 = 2;
-        for (int j = 0; j < QK_K; j += 64) {
-            const uchar4 sc = get_scale_min_k4(is, x[i].scales);
-            const float d1 = d * sc[0]; const float m1 = min * sc[1];
-            const float d2 = d * sc[2]; const float m2 = min * sc[3];
-            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
-            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
-            ql += 32; is += 2;
-            u1 <<= 2; u2 <<= 2;
-        }
-    }
-#else
-    for (int i = 0; i < nb; i++) {
-
-        const float d = (float)x[i].d;
-
-        device const uint8_t * ql = x[i].qs;
-        device const uint8_t * qh = x[i].qh;
-        device const int8_t  * sc = x[i].scales;
-
-        for (int l = 0; l < 8; ++l) {
-            y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
-            y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
-            y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
-            y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
-            y[l+32] = d * sc[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
-            y[l+40] = d * sc[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
-            y[l+48] = d * sc[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
-            y[l+56] = d * sc[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
-        }
-        y += QK_K;
-    }
-#endif
-
-}
-
-static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
-
-    for (int i = 0; i < nb; i++) {
-
-        device const uint8_t * ql = x[i].ql;
-        device const uint8_t * qh = x[i].qh;
-        device const int8_t  * sc = x[i].scales;
-
-        const float d = x[i].d;
-
-#if QK_K == 256
-        for (int n = 0; n < QK_K; n += 128) {
-            for (int l = 0; l < 32; ++l) {
-                int is = l/16;
-                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-                y[l +  0] = d * sc[is + 0] * q1;
-                y[l + 32] = d * sc[is + 2] * q2;
-                y[l + 64] = d * sc[is + 4] * q3;
-                y[l + 96] = d * sc[is + 6] * q4;
-            }
-            y  += 128;
-            ql += 64;
-            qh += 32;
-            sc += 8;
-        }
-#else
-        for (int l = 0; l < 16; ++l) {
-            const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-            const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-            const int8_t q3 = (int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-            const int8_t q4 = (int8_t)((ql[l+16]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-            y[l+ 0] = d * sc[0] * q1;
-            y[l+16] = d * sc[1] * q2;
-            y[l+32] = d * sc[2] * q3;
-            y[l+48] = d * sc[3] * q4;
-        }
-        y  += 64;
-#endif
-    }
-}
-
-kernel void kernel_get_rows_q2_K(
-        device const  void * src0,
-        device const   int * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb1,
-        uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q2_K(
-            (device const block_q2_K *) ((device char *) src0 + r*nb01),
-                       (device float *) ((device char *)  dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q3_K(
-        device const  void * src0,
-        device const   int * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb1,
-        uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q3_K(
-            (device const block_q3_K *) ((device char *) src0 + r*nb01),
-                       (device float *) ((device char *)  dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q4_K(
-        device const  void * src0,
-        device const   int * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb1,
-        uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q4_K(
-            (device const block_q4_K *) ((device char *) src0 + r*nb01),
-                       (device float *) ((device char *)  dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q5_K(
-        device const  void * src0,
-        device const   int * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb1,
-        uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q5_K(
-            (device const block_q5_K *) ((device char *) src0 + r*nb01),
-                       (device float *) ((device char *)  dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q6_K(
-        device const  void * src0,
-        device const   int * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb1,
-        uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q6_K(
-            (device const block_q6_K *) ((device char *) src0 + r*nb01),
-                       (device float *) ((device char *)  dst + i*nb1), ne00);
-}
-
 //====================================== dot products =========================
 
 kernel void kernel_mul_mat_q2_K_f32(
@@ -1224,21 +797,27 @@ kernel void kernel_mul_mat_q2_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
         constant   int64_t & ne01[[buffer(4)]],
-        uint2 tgpig[[threadgroup_position_in_grid]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
+    const int r2 = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int ib_row = first_row * nb;
-    device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
-    device const float      * y = (device const float      *) src1 + r1*ne10;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
@@ -1351,7 +930,7 @@ kernel void kernel_mul_mat_q2_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + first_row + row] = all_sum;
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -1362,10 +941,14 @@ kernel void kernel_mul_mat_q3_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        uint2 tgpig[[threadgroup_position_in_grid]],
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
@@ -1373,11 +956,12 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
+    const int64_t r2 = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-
-    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
     float yl[16];
 
@@ -1465,7 +1049,7 @@ kernel void kernel_mul_mat_q3_K_f32(
         const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
         const float tot = simd_sum(sumf);
         if (tiisg == 0) {
-            dst[r1*ne0 + first_row + row] = tot;
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
         }
     }
 }
@@ -1475,10 +1059,14 @@ kernel void kernel_mul_mat_q3_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        uint2 tgpig[[threadgroup_position_in_grid]],
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
@@ -1486,11 +1074,12 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
+    const int64_t r2 = tgpig.z;
 
     const int row = 2 * r0 + sgitg;
-
-    device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
     const int ix = tiisg/4;
     const int il = 4 * (tiisg%4);// 0, 4, 8, 12
     const int im = il/8;         // 0, 0, 1, 1
@@ -1529,7 +1118,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     const float tot = simd_sum(sumf);
     if (tiisg == 0) {
-        dst[r1*ne0 + row] = tot;
+        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
     }
 
 }
@@ -1541,10 +1130,14 @@ kernel void kernel_mul_mat_q4_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
         constant   int64_t & ne01[[buffer(4)]],
-        uint2 tgpig[[threadgroup_position_in_grid]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
@@ -1560,10 +1153,12 @@ kernel void kernel_mul_mat_q4_K_f32(
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
+    const int r2 = tgpig.z;
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int ib_row = first_row * nb;
-    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
-    device const float      * y = (device const float      *) src1 + r1*ne10;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
     float yl[16];
     float yh[16];
     float sumf[N_DST]={0.f}, all_sum;
@@ -1630,7 +1225,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + first_row + row] = all_sum;
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -1640,10 +1235,14 @@ kernel void kernel_mul_mat_q4_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
         constant   int64_t & ne01[[buffer(4)]],
-        uint2 tgpig[[threadgroup_position_in_grid]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
@@ -1653,10 +1252,12 @@ kernel void kernel_mul_mat_q4_K_f32(
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
+    const int r2 = tgpig.z;
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int ib_row = first_row * nb;
-    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
-    device const float      * y = (device const float      *) src1 + r1*ne10;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
     float yl[8];
     float yh[8];
     float sumf[N_DST]={0.f}, all_sum;
@@ -1712,7 +1313,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + first_row + row] = all_sum;
+            dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -1723,9 +1324,14 @@ kernel void kernel_mul_mat_q5_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
-        uint2 tgpig[[threadgroup_position_in_grid]],
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
@@ -1733,11 +1339,12 @@ kernel void kernel_mul_mat_q5_K_f32(
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
+    const int r2 = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-
-    device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
     float sumf[2]={0.f};
 
@@ -1871,7 +1478,7 @@ kernel void kernel_mul_mat_q5_K_f32(
     for (int row = 0; row < 2; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + first_row + row] = tot;
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
         }
     }
 
@@ -1882,9 +1489,14 @@ kernel void kernel_mul_mat_q6_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
-        uint2 tgpig[[threadgroup_position_in_grid]],
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
@@ -1897,11 +1509,12 @@ kernel void kernel_mul_mat_q6_K_f32(
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
+    const int r2 = tgpig.z;
 
     const int row = 2 * r0 + sgitg;
-
-    device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
     float sumf = 0;
 
@@ -1967,6 +1580,366 @@ kernel void kernel_mul_mat_q6_K_f32(
 
     const float tot = simd_sum(sumf);
     if (tiisg == 0) {
-        dst[r1*ne0 + row] = tot;
+        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+    }
+}
+
+//============================= templates and their specializations =============================
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+    half4x4 temp = *(((device half4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
+
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const half d = il ? (xb->d / 16.h) : xb->d;
+    const half m = il ? (-8.h * 16.h) : -8.h;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = il ? 0xF000 : 0x0F00;
+
+    for (int i=0;i<8;i++) {
+        reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
+        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const half d = il ? (xb->d / 16.h) : xb->d;
+    const half m = xb->m;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = il ? 0xF000 : 0x0F00;
+
+    for (int i=0;i<8;i++) {
+        reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
+        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+    const half d = xb->d;
+    const half min = xb->dmin;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    half dl, ml;
+    uint8_t sc = xb->scales[il];
+
+#if QK_K == 256
+    q = q + 32*(il/8) + 16*(il&1);
+    il = (il/2)%4;
+#endif
+    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+    const float d_all = (float)(xb->d);
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    device const uint8_t * h = (device const uint8_t *)xb->hmask;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+    q = q + 32 * (il/8) + 16 * (il&1);
+    h = h + 16 * (il&1);
+    uint8_t m = 1 << (il/2);
+    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+                                 ((il/4)>0 ? 12  : 3);
+    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
+                                 (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+
+    il = (il/2)%4;
+    float   coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+    }
+#else
+    float    kcoef = il&1 ? 1.f/16.f : 1.f;
+    uint16_t kmask = il&1 ? 0xF0     : 0x0F;
+    float    dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
+    float    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uint8_t  mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    uint8_t  m = 1<<(il*2);
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
+    }
+#endif
+}
+
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q = xb->qs;
+
+#if QK_K == 256
+    const float d = (float)(xb->d);
+    const float min = (float)(xb->dmin);
+    short is = (il/4) * 2;
+    q = q + (il/4) * 32 + 16 * (il&1);
+    il = il%4;
+    const uchar4 sc = get_scale_min_k4(is, xb->scales);
+    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
+    const float ml = il<2 ? min * sc[1] : min * sc[3];
+#else
+    q = q + 16 * (il&1);
+    device const uint8_t * s = xb->scales;
+    device const half2 * dh = (device const half2 *)xb->d;
+    const float2 d = (float2)dh[0];
+    const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
+    const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1 ]* (s[1]>>4);
+#endif
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q  = xb->qs;
+    device const uint8_t * qh = xb->qh;
+
+#if QK_K == 256
+    const float d = (float)(xb->d);
+    const float min = (float)(xb->dmin);
+    short is = (il/4) * 2;
+    q  = q + 32 * (il/4) + 16 * (il&1);
+    qh = qh + 16 * (il&1);
+    uint8_t ul = 1 << (il/2);
+    il = il%4;
+    const uchar4 sc = get_scale_min_k4(is, xb->scales);
+    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
+    const float ml = il<2 ? min * sc[1] : min * sc[3];
+
+    const ushort mask   = il<2 ? 0x0F : 0xF0;
+    const float  qh_val = il<2 ? 16.f : 256.f;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+    }
+#else
+    q = q + 16 * (il&1);
+    device const int8_t * s = xb->scales;
+    const float dl = xb->d * s[il];
+    uint8_t m = 1<<(il*2);
+    const float  coef = il<2 ? 1.f  : 1.f/16.f;
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
+    }
+#endif
+}
+
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+    const float d_all = (float)(xb->d);
+    device const uint8_t * ql = (device const uint8_t *)xb->ql;
+    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+    qh = qh + 32*(il/8) + 16*(il&1);
+    float sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2)%4;
+#else
+    ql = ql + 16 * (il&1);
+    float sc = scales[il];
+#endif
+    for (int i = 0; i < 16; ++i) {
+        uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+        uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
+        const float coef = il>1 ? 1.f/16.f          : 1.f;
+        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
+                         ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
+        reg[i/4][i%4] = d_all * sc * q * coef;
+    }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
+kernel void kernel_get_rows(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint                 tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint                 tptg[[threads_per_threadgroup]]) {
+    const int i = tgpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+        float4x4 temp;
+        dequantize_func(
+            ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+    }
+}
+
+#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+#define BLOCK_SIZE_K 32
+#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
+#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
+#define THREAD_PER_BLOCK 128
+#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
+#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
+#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
+#define SG_MAT_ROW 8
+
+// each block_q contains 16*nl weights
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm(device const  uchar * src0,
+                           device const  float * src1,
+                           device        float * dst,
+                           constant    int64_t & ne00,
+                           constant    int64_t & ne02,
+                           constant    int64_t & nb01,
+                           constant    int64_t & nb02,
+                           constant    int64_t & ne12,
+                           constant    int64_t & ne0,
+                           constant    int64_t & ne1,
+                           constant    uint & gqa,
+                           threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                           uint3                 tgpig[[threadgroup_position_in_grid]],
+                           uint                  tiitg[[thread_index_in_threadgroup]],
+                           uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half * sa = ((threadgroup half *)shared_memory);
+    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+    const uint r0 = tgpig.y;
+    const uint r1 = tgpig.x;
+    const uint im = tgpig.z;
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8 ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 c_res[8];
+    for (int i = 0; i < 8; i++){
+        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    short il = (tiitg % THREAD_PER_ROW);
+    uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
+    device const block_q  * x = (device const block_q  *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
+                             + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        //load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        #pragma unroll(16)
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
+            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
+                = *((device float2x4 *)y);
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        //load matrices from threadgroup memory and conduct outer products
+        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+        #pragma unroll(4)
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            #pragma unroll(4)
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            #pragma unroll(2)
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+            #pragma unroll(8)
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            }
+        }
+    }
+
+    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
+        device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
+                          + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+        }
+    } else {
+        // block is smaller than 64x32, we should avoid writing data outside of the matrix
+        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg==0) {
+            for (int i = 0; i < n_rows; i++) {
+                for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+                }
+            }
+        }
     }
 }
+
+#if QK_K == 256
+#define QK_NL 16
+#else
+#define QK_NL 4
+#endif
+
+typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
+                          constant uint64_t &, constant uint64_t &, uint, uint, uint);
+
+template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+
+typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
+                             constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
+                             constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
index c8ab313d9c11639c5739a49b6d35e47341531ba2..a161f1566db9ce4f87172c0354a5d354b532e875 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1845,7 +1845,7 @@ static bool llama_eval_internal(
 #endif
 
 #ifdef GGML_USE_METAL
-    if (lctx.ctx_metal && N == 1) {
+    if (lctx.ctx_metal) {
         // TODO: disabled until #2413 is resolved
         //if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
         //    ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf);
@@ -1857,22 +1857,6 @@ static bool llama_eval_internal(
             ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
         }
     } else {
-        // IMPORTANT:
-        // Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla
-        // ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX
-        // coprocessor.
-        //
-        // When we implement Matrix x Matrix Metal multiplication, we can avoid this branch.
-        // But for now, we have focused only on Matrix x Vector Metal multiplication.
-        //
-        // TODO: avoid these syncs via shared memory (ref #1696)
-        //
-        if (lctx.ctx_metal) {
-            // We need to sync the GPU KV cache with the CPU KV cache
-            ggml_metal_get_tensor(lctx.ctx_metal, kv_self.k);
-            ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
-        }
-
         ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
     }
 #else