]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : sync ggml-metal (ref #1047)
authorGeorgi Gerganov <redacted>
Sun, 25 Jun 2023 12:40:30 +0000 (15:40 +0300)
committerGeorgi Gerganov <redacted>
Sun, 25 Jun 2023 12:40:39 +0000 (15:40 +0300)
extra/sync-ggml.sh
ggml-metal.h [new file with mode: 0644]
ggml-metal.m [new file with mode: 0644]
ggml-metal.metal [new file with mode: 0644]

index e242ea7ece584bb7a4a0b2a14645278fbfb1c922..3bd99e3a478f2be35ed666c645936624c23661f6 100755 (executable)
@@ -5,6 +5,9 @@ cp -rpv ../ggml/src/ggml-cuda.h          ./ggml-cuda.h
 cp -rpv ../ggml/src/ggml-cuda.cu         ./ggml-cuda.cu
 cp -rpv ../ggml/src/ggml-opencl.h        ./ggml-opencl.h
 cp -rpv ../ggml/src/ggml-opencl.cpp      ./ggml-opencl.cpp
+cp -rpv ../ggml/src/ggml-metal.h         ./ggml-metal.h
+cp -rpv ../ggml/src/ggml-metal.m         ./ggml-metal.m
+cp -rpv ../ggml/src/ggml-metal.metal     ./ggml-metal.metal
 cp -rpv ../ggml/include/ggml/ggml.h      ./ggml.h
 cp -rpv ../ggml/examples/common.h        ./examples/common.h
 cp -rpv ../ggml/examples/common.cpp      ./examples/common.cpp
diff --git a/ggml-metal.h b/ggml-metal.h
new file mode 100644 (file)
index 0000000..b9e50ac
--- /dev/null
@@ -0,0 +1,67 @@
+// An interface allowing to compute ggml_cgraph with Metal
+//
+// This is a fully functional interface that extends ggml with GPU support for Apple devices.
+// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.)
+//
+// How it works?
+//
+// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
+// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
+// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
+//
+// You only need to make sure that all memory buffers that you used during the graph creation
+// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
+// used during the graph evaluation to determine the arguments of the compute kernels.
+//
+// Synchronization between device and host memory (for example for input and output tensors)
+// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
+//
+
+#pragma once
+
+#include <stddef.h>
+#include <stdbool.h>
+
+// max memory buffers that can be mapped to the device
+#define GGML_METAL_MAX_BUFFERS 16
+
+struct ggml_tensor;
+struct ggml_cgraph;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_metal_context;
+
+struct ggml_metal_context * ggml_metal_init(void);
+void ggml_metal_free(struct ggml_metal_context * ctx);
+
+// creates a mapping between a host memory buffer and a device memory buffer
+// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
+// - the mapping is used during computation to determine the arguments of the compute kernels
+// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
+// - max_size specifies the maximum size of a tensor and is used to create shared views such
+//   that it is guaranteed that the tensor will fit in at least one of the views
+//
+bool ggml_metal_add_buffer(
+        struct ggml_metal_context * ctx,
+                       const char * name,
+                             void * data,
+                           size_t   size,
+                           size_t   max_size);
+
+// set data from host memory into the device
+void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
+
+// get data from the device into host memory
+void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
+
+// same as ggml_graph_compute but uses Metal
+// creates gf->n_threads command buffers in parallel
+void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+#ifdef __cplusplus
+}
+#endif
+
diff --git a/ggml-metal.m b/ggml-metal.m
new file mode 100644 (file)
index 0000000..a7e104d
--- /dev/null
@@ -0,0 +1,972 @@
+#import "ggml-metal.h"
+
+#import "ggml.h"
+
+#import <Foundation/Foundation.h>
+
+#import <Metal/Metal.h>
+#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
+
+#ifdef GGML_METAL_NDEBUG
+#define metal_printf(...)
+#else
+#define metal_printf(...) fprintf(stderr, __VA_ARGS__)
+#endif
+
+#define UNUSED(x) (void)(x)
+
+struct ggml_metal_buffer {
+    const char * name;
+
+    void   * data;
+    size_t   size;
+
+    id<MTLBuffer> metal;
+};
+
+struct ggml_metal_context {
+    float * logits;
+
+    id<MTLDevice>       device;
+    id<MTLCommandQueue> queue;
+    id<MTLLibrary>      library;
+
+    int n_buffers;
+    struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
+
+    // custom kernels
+#define GGML_METAL_DECL_KERNEL(name) \
+    id<MTLFunction>             function_##name; \
+    id<MTLComputePipelineState> pipeline_##name
+
+    GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(mul);
+    GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+    GGML_METAL_DECL_KERNEL(scale);
+    GGML_METAL_DECL_KERNEL(silu);
+    GGML_METAL_DECL_KERNEL(relu);
+    GGML_METAL_DECL_KERNEL(gelu);
+    GGML_METAL_DECL_KERNEL(soft_max);
+    GGML_METAL_DECL_KERNEL(diag_mask_inf);
+    GGML_METAL_DECL_KERNEL(get_rows_f16);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q2_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q3_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q6_k);
+    GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(norm);
+    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
+    GGML_METAL_DECL_KERNEL(rope);
+    GGML_METAL_DECL_KERNEL(alibi_f32);
+    GGML_METAL_DECL_KERNEL(cpy_f32_f16);
+    GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+
+#undef GGML_METAL_DECL_KERNEL
+};
+
+// MSL code
+// TODO: move the contents here when ready
+//       for now it is easier to work in a separate file
+static NSString * const msl_library_source = @"see metal.metal";
+
+// Here to assist with NSBundle Path Hack
+@interface GGMLMetalClass : NSObject
+@end
+@implementation GGMLMetalClass
+@end
+
+struct ggml_metal_context * ggml_metal_init(void) {
+    fprintf(stderr, "%s: allocating\n", __func__);
+
+    struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+
+    ctx->device = MTLCreateSystemDefaultDevice();
+    ctx->queue  = [ctx->device newCommandQueue];
+    ctx->n_buffers = 0;
+
+    // determine if we can use MPS
+    if (MPSSupportsMTLDevice(ctx->device)) {
+        fprintf(stderr, "%s: using MPS\n", __func__);
+    } else {
+        fprintf(stderr, "%s: not using MPS\n", __func__);
+        GGML_ASSERT(false && "MPS not supported");
+    }
+
+#if 0
+    // compile from source string and show compile log
+    {
+        NSError * error = nil;
+
+        ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
+        if (error) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+    }
+#else
+    UNUSED(msl_library_source);
+
+    // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
+    {
+        NSError * error = nil;
+
+        //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
+        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+        NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+        fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+
+        NSString * src  = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
+        if (error) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+
+        ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+        if (error) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+    }
+#endif
+
+    // load kernels
+    {
+#define GGML_METAL_ADD_KERNEL(name) \
+        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
+        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
+
+        GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(mul);
+        GGML_METAL_ADD_KERNEL(mul_row);
+        GGML_METAL_ADD_KERNEL(scale);
+        GGML_METAL_ADD_KERNEL(silu);
+        GGML_METAL_ADD_KERNEL(relu);
+        GGML_METAL_ADD_KERNEL(gelu);
+        GGML_METAL_ADD_KERNEL(soft_max);
+        GGML_METAL_ADD_KERNEL(diag_mask_inf);
+        GGML_METAL_ADD_KERNEL(get_rows_f16);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        GGML_METAL_ADD_KERNEL(get_rows_q2_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q3_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q6_k);
+        GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(norm);
+        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
+        GGML_METAL_ADD_KERNEL(rope);
+        GGML_METAL_ADD_KERNEL(alibi_f32);
+        GGML_METAL_ADD_KERNEL(cpy_f32_f16);
+        GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+
+#undef GGML_METAL_ADD_KERNEL
+    }
+
+    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    fprintf(stderr, "%s: hasUnifiedMemory             = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    if (ctx->device.maxTransferRate != 0) {
+        fprintf(stderr, "%s: maxTransferRate              = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+    } else {
+        fprintf(stderr, "%s: maxTransferRate              = built-in GPU\n", __func__);
+    }
+
+    return ctx;
+}
+
+void ggml_metal_free(struct ggml_metal_context * ctx) {
+    fprintf(stderr, "%s: deallocating\n", __func__);
+
+    free(ctx);
+}
+
+// finds the Metal buffer that contains the tensor data on the GPU device
+// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
+// Metal buffer based on the host memory pointer
+//
+static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
+    //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+
+    const int64_t tsize = ggml_nbytes(t);
+
+    // find the view that contains the tensor fully
+    for (int i = 0; i < ctx->n_buffers; ++i) {
+        const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
+
+        if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
+            *offs = (size_t) ioffs;
+
+            //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+
+            return ctx->buffers[i].metal;
+        }
+    }
+
+    fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+
+    return nil;
+}
+
+bool ggml_metal_add_buffer(
+        struct ggml_metal_context * ctx,
+                     const char * name,
+                           void * data,
+                         size_t   size,
+                         size_t   max_size) {
+    if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
+        fprintf(stderr, "%s: too many buffers\n", __func__);
+        return false;
+    }
+
+    if (data) {
+        // verify that the buffer does not overlap with any of the existing buffers
+        for (int i = 0; i < ctx->n_buffers; ++i) {
+            const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
+
+            if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
+                fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+                return false;
+            }
+        }
+
+        const size_t size_page = getpagesize();
+
+        size_t size_aligned = size;
+        if ((size_aligned % size_page) != 0) {
+            size_aligned += (size_page - (size_aligned % size_page));
+        }
+
+        // the buffer fits into the max buffer size allowed by the device
+        if (size_aligned <= ctx->device.maxBufferLength) {
+            ctx->buffers[ctx->n_buffers].name = name;
+            ctx->buffers[ctx->n_buffers].data = data;
+            ctx->buffers[ctx->n_buffers].size = size;
+
+            ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+            if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+                return false;
+            }
+
+            fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+
+            ++ctx->n_buffers;
+        } else {
+            // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+            // one of the views
+            const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+            const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
+            const size_t size_view = ctx->device.maxBufferLength;
+
+            for (size_t i = 0; i < size; i += size_step) {
+                const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+                ctx->buffers[ctx->n_buffers].name = name;
+                ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
+                ctx->buffers[ctx->n_buffers].size = size_step_aligned;
+
+                ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+                if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                    fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                    return false;
+                }
+
+                fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+                if (i + size_step < size) {
+                    fprintf(stderr, "\n");
+                }
+
+                ++ctx->n_buffers;
+            }
+        }
+
+        fprintf(stderr, ", (%8.2f / %8.2f)",
+                ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
+                ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+        if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
+            fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
+        } else {
+            fprintf(stderr, "\n");
+        }
+    }
+
+    return true;
+}
+
+void ggml_metal_set_tensor(
+        struct ggml_metal_context * ctx,
+        struct ggml_tensor * t) {
+    metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
+
+    size_t offs;
+    id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
+
+    memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
+}
+
+void ggml_metal_get_tensor(
+        struct ggml_metal_context * ctx,
+        struct ggml_tensor * t) {
+    metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
+
+    size_t offs;
+    id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
+
+    memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
+}
+
+void ggml_metal_graph_compute(
+        struct ggml_metal_context * ctx,
+               struct ggml_cgraph * gf) {
+    metal_printf("%s: evaluating graph\n", __func__);
+
+    // create multiple command buffers and enqueue them
+    // then, we encode the graph into the command buffers in parallel
+
+    const int n_cb = gf->n_threads;
+
+    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
+
+    for (int i = 0; i < n_cb; ++i) {
+        command_buffers[i] = [ctx->queue commandBuffer];
+
+        // enqueue the command buffers in order to specify their execution order
+        [command_buffers[i] enqueue];
+    }
+
+    // TODO: is this the best way to start threads?
+    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+
+    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+
+        dispatch_async(queue, ^{
+            size_t offs_src0 = 0;
+            size_t offs_src1 = 0;
+            size_t offs_dst  = 0;
+
+            id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
+
+            id<MTLComputeCommandEncoder> encoder = nil;
+
+            const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int i = node_start; i < node_end; ++i) {
+                metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+
+                struct ggml_tensor * src0 = gf->nodes[i]->src0;
+                struct ggml_tensor * src1 = gf->nodes[i]->src1;
+                struct ggml_tensor * dst  = gf->nodes[i];
+
+                const int64_t  ne00 = src0 ? src0->ne[0] : 0;
+                const int64_t  ne01 = src0 ? src0->ne[1] : 0;
+                const int64_t  ne02 = src0 ? src0->ne[2] : 0;
+                const int64_t  ne03 = src0 ? src0->ne[3] : 0;
+
+                const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+                const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+                const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+                const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+
+                const int64_t  ne10 = src1 ? src1->ne[0] : 0;
+                const int64_t  ne11 = src1 ? src1->ne[1] : 0;
+                const int64_t  ne12 = src1 ? src1->ne[2] : 0;
+                const int64_t  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+                const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+                const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+                const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+                const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+
+                const int64_t  ne0  = dst ? dst->ne[0] : 0;
+                const int64_t  ne1  = dst ? dst->ne[1] : 0;
+                const int64_t  ne2  = dst ? dst->ne[2] : 0;
+                const int64_t  ne3  = dst ? dst->ne[3] : 0;
+
+                const uint64_t nb0  = dst ? dst->nb[0] : 0;
+                const uint64_t nb1  = dst ? dst->nb[1] : 0;
+                const uint64_t nb2  = dst ? dst->nb[2] : 0;
+                const uint64_t nb3  = dst ? dst->nb[3] : 0;
+
+                const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+                const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+                const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
+
+                id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
+                id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
+                id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(ctx, dst,  &offs_dst)  : nil;
+
+                //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+                //if (src0) {
+                //    metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+                //            ggml_is_contiguous(src0), src0->name);
+                //}
+                //if (src1) {
+                //    metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+                //            ggml_is_contiguous(src1), src1->name);
+                //}
+                //if (dst) {
+                //    metal_printf("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
+                //            dst->name);
+                //}
+
+                switch (dst->op) {
+                    case GGML_OP_RESHAPE:
+                    case GGML_OP_VIEW:
+                    case GGML_OP_TRANSPOSE:
+                    case GGML_OP_PERMUTE:
+                        {
+                            // noop
+                        } break;
+                    case GGML_OP_ADD:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_MUL:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            if (ggml_nelements(src1) == ne10) {
+                                // src1 is a row
+                                [encoder setComputePipelineState:ctx->pipeline_mul_row];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_mul];
+                            }
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_SCALE:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const float scale = *(const float *) src1->data;
+
+                            [encoder setComputePipelineState:ctx->pipeline_scale];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_SILU:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_silu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_RELU:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_relu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_GELU:
+                    {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_gelu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                    case GGML_OP_SOFT_MAX:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int nth = 32;
+
+                            [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_DIAG_MASK_INF:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int n_past = ((int32_t *)(src1->data))[0];
+
+                            [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_MUL_MAT:
+                        {
+                            // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
+
+                            GGML_ASSERT(ne00 == ne10);
+                            GGML_ASSERT(ne02 == ne12);
+
+                            if (ggml_is_contiguous(src0) &&
+                                ggml_is_contiguous(src1) &&
+                                (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
+
+                                if (encoder != nil) {
+                                    [encoder endEncoding];
+                                    encoder = nil;
+                                }
+
+                                MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+                                MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+
+                                // for F32 x F32 we use MPS
+                                MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
+                                    matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
+
+                                MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
+                                    matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
+
+                                MPSMatrixDescriptor * desc  = [MPSMatrixDescriptor
+                                    matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
+
+                                MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
+                                    initWithDevice:ctx->device transposeLeft:false transposeRight:true
+                                        resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
+
+                                // we need to do ne02 multiplications
+                                // TODO: is there a way to do this in parallel - currently very slow ..
+                                // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
+                                for (int64_t i02 = 0; i02 < ne02; ++i02) {
+                                    size_t offs_src0_cur = offs_src0 + i02*nb02;
+                                    size_t offs_src1_cur = offs_src1 + i02*nb12;
+                                    size_t offs_dst_cur  = offs_dst  + i02*nb2;
+
+                                    MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
+                                    MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
+                                    MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];
+
+                                    [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+                                }
+                            } else {
+                                if (encoder == nil) {
+                                    encoder = [command_buffer computeCommandEncoder];
+                                }
+
+                                int nth0 = 32;
+                                int nth1 = 1;
+
+                                // use custom matrix x vector kernel
+                                switch (src0t) {
+                                    case GGML_TYPE_F16:
+                                        {
+                                            GGML_ASSERT(ne02 == ne12);
+
+                                            nth0 = 64;
+                                            nth1 = 1;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_0:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_1:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+                                        } break;
+                                    case GGML_TYPE_Q2_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q3_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q6_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
+                                        } break;
+                                    default:
+                                        {
+                                            fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+                                            GGML_ASSERT(false && "not implemented");
+                                        }
+                                };
+
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
+                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
+                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
+                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
+                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
+                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
+                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
+
+                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
+                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src0t == GGML_TYPE_Q2_K ||
+                                         src0t == GGML_TYPE_Q3_K ||
+                                         src0t == GGML_TYPE_Q4_K ||
+                                         src0t == GGML_TYPE_Q5_K ||
+                                         src0t == GGML_TYPE_Q6_K) {
+                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                } else {
+                                    [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                            }
+                        } break;
+                    case GGML_OP_GET_ROWS:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            switch (src0->type) {
+                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
+                                case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+                                case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
+                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
+                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
+                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
+                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
+                                default: GGML_ASSERT(false && "not implemented");
+                            }
+
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
+                            [encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];
+
+                            const int64_t n = ggml_nelements(src1);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_RMS_NORM:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const float eps = 1e-6f;
+
+                            const int nth = 256;
+
+                            [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                            const int64_t nrows = ggml_nrows(src0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_NORM:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const float eps = 1e-5f;
+
+                            const int nth = 256;
+
+                            [encoder setComputePipelineState:ctx->pipeline_norm];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                            const int64_t nrows = ggml_nrows(src0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_ALIBI:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            GGML_ASSERT((src0t == GGML_TYPE_F32));
+
+                            const int   n_past   = ((int32_t *) src1->data)[0]; UNUSED(n_past);
+                            const int   n_head   = ((int32_t *) src1->data)[1];
+                            const float max_bias = ((float *)   src1->data)[2];
+
+                            if (__builtin_popcount(n_head) != 1) {
+                                GGML_ASSERT(false && "only power-of-two n_head implemented");
+                            }
+
+                            const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+                            const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+
+                            [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&m0  length:sizeof(    float) atIndex:18];
+                            const int nth = 32;
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_ROPE:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int n_dims = ((int32_t *) src1->data)[1];
+                            const int mode   = ((int32_t *) src1->data)[2];
+
+                            const int n_past = ((int32_t *)(src1->data))[0];
+
+                            [encoder setComputePipelineState:ctx->pipeline_rope];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00   length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01   length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02   length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03   length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00   length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01   length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02   length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03   length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0    length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1    length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2    length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3    length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0    length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1    length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2    length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3    length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&n_past length:sizeof(     int) atIndex:18];
+                            [encoder setBytes:&n_dims length:sizeof(     int) atIndex:19];
+                            [encoder setBytes:&mode   length:sizeof(     int) atIndex:20];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_CPY:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int nth = 32;
+
+                            switch (src0t) {
+                                case GGML_TYPE_F32:
+                                    {
+                                        switch (dstt) {
+                                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+                                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+                                            default: GGML_ASSERT(false && "not implemented");
+                                        };
+                                    } break;
+                                case GGML_TYPE_F16:
+                                    {
+                                        switch (dstt) {
+                                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
+                                            case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+                                            default: GGML_ASSERT(false && "not implemented");
+                                        };
+                                    } break;
+                                default: GGML_ASSERT(false && "not implemented");
+                            }
+
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    default:
+                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                        GGML_ASSERT(false);
+                }
+            }
+
+            if (encoder != nil) {
+                [encoder endEncoding];
+                encoder = nil;
+            }
+
+            [command_buffer commit];
+        });
+    }
+
+    // wait for all threads to finish
+    dispatch_barrier_sync(queue, ^{});
+
+    [command_buffers[n_cb - 1] waitUntilCompleted];
+
+    // check status of command buffers
+    // needed to detect if the device ran out-of-memory for example (#1881)
+    for (int i = 0; i < n_cb; i++) {
+        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+        if (status != MTLCommandBufferStatusCompleted) {
+            fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
+            GGML_ASSERT(false);
+        }
+    }
+}
diff --git a/ggml-metal.metal b/ggml-metal.metal
new file mode 100644 (file)
index 0000000..d1e4922
--- /dev/null
@@ -0,0 +1,1585 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+
+#define QK4_0 32
+#define QR4_0 2
+typedef struct {
+    half    d;             // delta
+    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+
+#define QK4_1 32
+typedef struct {
+    half d;          // delta
+    half m;          // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+
+static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
+    const int qk = QK4_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const half d = x[i].d;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F) - 8;
+            const int x1 = (x[i].qs[j] >>   4) - 8;
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
+static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
+    const int qk = QK4_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const half d = x[i].d;
+        const half m = x[i].m;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F);
+            const int x1 = (x[i].qs[j] >>   4);
+
+            y[i*qk + j + 0   ] = x0*d + m;
+            y[i*qk + j + qk/2] = x1*d + m;
+        }
+    }
+}
+
+kernel void kernel_add(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig];
+}
+
+kernel void kernel_mul(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig];
+}
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_mul_row(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+}
+
+kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_silu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    float x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_relu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = max(0.0f, src0[tpig]);
+}
+
+constant float GELU_COEF_A    = 0.044715f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    float x = src0[tpig];
+    dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_soft_max(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    // parallel max
+    buf[tpitg[0]] = -INFINITY;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg[0]/2; i > 0; i /= 2) {
+        if (tpitg[0] < i) {
+            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg[0] == 0) {
+        buf[0] = buf[0];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float max = buf[0];
+
+    // parallel sum
+    buf[tpitg[0]] = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        buf[tpitg[0]] += exp(psrc0[i00] - max);
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg[0]/2; i > 0; i /= 2) {
+        if (tpitg[0] < i) {
+            buf[tpitg[0]] += buf[tpitg[0] + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg[0] == 0) {
+        buf[0] = buf[0];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float sum = buf[0];
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] = exp(psrc0[i00] - max) / sum;
+    }
+}
+
+kernel void kernel_diag_mask_inf(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant       int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i02 = tpig[2];
+    const int64_t i01 = tpig[1];
+    const int64_t i00 = tpig[0];
+
+    if (i00 > n_past + i01) {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    } else {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_get_rows_f16(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    for (int j = 0; j < ne00; j++) {
+        dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
+    }
+}
+
+kernel void kernel_get_rows_q4_0(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q4_0(
+            (device const block_q4_0 *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q4_1(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q4_1(
+            (device const block_q4_1 *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    // MEAN
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float mean  = sum[0];
+
+    // recenter
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += y[i00] * y[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float variance = sum[0];
+
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+
+kernel void kernel_rms_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00] * x[i00];
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float mean  = sum[0];
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] * scale;
+    }
+}
+
+kernel void kernel_mul_mat_q4_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+    const int nb = ne00/QK4_0;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int ix = tpitg.y/4;           // 0 or 1
+    const int iy = tpitg.y - 4*ix;      // 0...3
+
+    const int first = 4 * iy;
+
+    float sumf = 0;
+
+    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
+
+        const float d = (float)x[i].d;
+
+        device const uint8_t * xl = x[i].qs + first;
+        device const float   * yl = y + i * QK4_0 + first;
+
+        float2 acc = {0.0f, 0.0f};
+
+        for (int j = 0; j < 4; ++j) {
+
+            acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
+            acc[1] += yl[j] + yl[j+16];
+
+        }
+
+        sumf += d * (acc[0] - 8.f*acc[1]);
+    }
+
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+}
+
+kernel void kernel_mul_mat_q4_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+    const int nb = ne00/QK4_1;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+
+    const uint nth = tptg.x*tptg.y;
+    const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int ix = tpitg.y/4;           // 0 or 1
+    const int iy = tpitg.y - 4*ix;      // 0...3
+
+    const int first = 4 * iy;
+
+    float sumf = 0;
+
+    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
+
+        const float d = (float)x[i].d;
+        const float m = (float)x[i].m;
+
+        device const uint8_t * xl = x[i].qs + first;
+        device const float   * yl = y + i * QK4_1 + first;
+
+        float2 acc = {0.0f, 0.0f};
+
+        for (int j = 0; j < 4; ++j) {
+
+            acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
+            acc[1] += yl[j+16] * (d * (xl[j] >>  4) + m);
+
+        }
+
+        sumf += acc[0] + acc[1];
+    }
+
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+}
+
+kernel void kernel_mul_mat_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tpig[[thread_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3  tptg[[threads_per_threadgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+    const int64_t im = tgpig.z;
+
+    device const half  * x = (device const half  *) (src0 + r0*nb01 + im*nb02);
+    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+    sum[tpitg.x] = 0.0f;
+
+    for (int i = tpitg.x; i < ne00; i += tptg.x) {
+        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    }
+
+    // accumulate the sum from all threads in the threadgroup
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = tptg.x/2; i > 0; i /= 2) {
+        if (tpitg.x < i) {
+            sum[tpitg.x] += sum[tpitg.x + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    if (tpitg.x == 0) {
+        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+    }
+}
+
+kernel void kernel_alibi_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant      float & m0,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    float m_k = pow(m0, i2 + 1);
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+    }
+}
+
+kernel void kernel_rope(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant       int & n_past,
+        constant       int & n_dims,
+        constant       int & mode,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i3 = tpig[2];
+    const int64_t i2 = tpig[1];
+    const int64_t i1 = tpig[0];
+
+    const bool is_neox = mode & 2;
+    const float theta_scale = pow(10000.0, -2.0f/n_dims);
+
+    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+
+    float theta = (float)p;
+
+    if (!is_neox) {
+        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+            const float cos_theta = cos(theta);
+            const float sin_theta = sin(theta);
+
+            theta *= theta_scale;
+
+            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        }
+    } else {
+        // TODO: implement
+    }
+}
+
+kernel void kernel_cpy_f16_f16(
+        device const half * src0,
+        device       half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f16(
+        device const float * src0,
+        device        half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
+
+//============================================ k-quants ======================================================
+
+#define QK_K 256
+
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    half d;           // super-block scale for quantized scales
+    half dmin;        // super-block scale for quantized mins
+} block_q2_k;
+// 84 bytes / block
+
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    half d;                    // super-block scale
+} block_q3_k;
+// 110 bytes / block
+
+typedef struct {
+    half d;             // super-block scale for quantized scales
+    half dmin;          // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_k;
+// 144 bytes / block
+
+typedef struct {
+    half d;                      // super-block scale for quantized scales
+    half dmin;                   // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64];   // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_k;
+// 176 bytes / block
+
+typedef struct {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    half d;                  // super-block scale
+} block_q6_k;
+// 210 bytes / block
+
+static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
+    uchar4 r;
+    if (j < 4) {
+        r[0] = q[j+0] & 63;
+        r[2] = q[j+1] & 63;
+        r[1] = q[j+4] & 63;
+        r[3] = q[j+5] & 63;
+    } else {
+        r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
+        r[1] = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+        r[3] = (q[j+5] >>  4) | ((q[j+1] >> 6) << 4);
+    }
+    return r;
+}
+
+//========================================== dequantization =============================
+
+static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = x[i].d;
+        const float min = x[i].dmin;
+
+        device const uint8_t * q = x[i].qs;
+
+        int is = 0;
+        float dl, ml;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                uint8_t sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+
+    }
+}
+
+static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    uint16_t aux[8];
+    thread const int8_t * scales = (thread const int8_t*)aux;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs;
+        device const uint8_t * h = x[i].hmask;
+        uint8_t m = 1;
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
+        aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
+        aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
+        aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
+        aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
+        aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
+        aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
+        aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+
+    }
+
+}
+
+static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = x[i].d;
+        const float min = x[i].dmin;
+
+        device const uint8_t * q = x[i].qs;
+        device const uint8_t * scales = x[i].scales;
+
+        int is = 0;
+        for (int j = 0; j < QK_K; j += 64) {
+            const uchar4 sc = get_scale_min_k4(is, scales);
+            const float d1 = d * sc[0]; const float m1 = min * sc[1];
+            const float d2 = d * sc[2]; const float m2 = min * sc[3];
+            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
+            q += 32; is += 2;
+        }
+
+    }
+}
+
+static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+   for (int i = 0; i < nb; i++) {
+
+        const float d = (float)(x[i].d);
+        const float min = (float)(x[i].dmin);
+
+        device const uint8_t * ql = x[i].qs;
+        device const uint8_t * qh = x[i].qh;
+
+        int is = 0;
+        uint8_t u1 = 1, u2 = 2;
+        for (int j = 0; j < QK_K; j += 64) {
+            const uchar4 sc = get_scale_min_k4(is, x[i].scales);
+            const float d1 = d * sc[0]; const float m1 = min * sc[1];
+            const float d2 = d * sc[2]; const float m2 = min * sc[3];
+            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+    }
+
+}
+
+static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        device const uint8_t * ql = x[i].ql;
+        device const uint8_t * qh = x[i].qh;
+        device const int8_t  * sc = x[i].scales;
+
+        const float d = x[i].d;
+
+        for (int n = 0; n < QK_K; n += 128) {
+            for (int l = 0; l < 32; ++l) {
+                int is = l/16;
+                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+                y[l +  0] = d * sc[is + 0] * q1;
+                y[l + 32] = d * sc[is + 2] * q2;
+                y[l + 64] = d * sc[is + 4] * q3;
+                y[l + 96] = d * sc[is + 6] * q4;
+            }
+            y  += 128;
+            ql += 64;
+            qh += 32;
+            sc += 8;
+        }
+    }
+}
+
+kernel void kernel_get_rows_q2_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q2_k(
+            (device const block_q2_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q3_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q3_k(
+            (device const block_q3_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q4_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q4_k(
+            (device const block_q4_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q5_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q5_k(
+            (device const block_q5_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q6_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q6_k(
+            (device const block_q6_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+//====================================== dot products =========================
+
+kernel void kernel_mul_mat_q2_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;    // 0...16
+    const int il  = tid/4;      // 0...3
+    const int ir  = tid%4;      // 0...3
+    const int ip  = il/2;       // 0 or 1
+    const int shift1 = 4*(il%2);// 0 or 4
+    const int shift2 = shift1+2;// 2 or 6
+    const int n   = 8;
+    const int is  = 4*il + (n*ir)/16;
+
+    const int y_offset = 64*il + n*ir;
+    const int q_offset = 32*ip + n*ir;
+
+    sum[ith] = 0.0f;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q = x[i].qs + q_offset;
+        device const uint8_t * scales = x[i].scales + is;
+
+        uint8_t d1 = scales[0] & 0xF;
+        uint8_t d2 = scales[2] & 0xF;
+        uint8_t m1 = scales[0] >>  4;
+        uint8_t m2 = scales[2] >>  4;
+
+        device const float   * y = yy + i*QK_K + y_offset;
+
+        //float4 s = {0.f, 0.f, 0.f, 0.f};
+        float2 s = {0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
+            s[1] += y[l+32] * ((q[l] >> shift2) & 3);
+            smin += y[l+ 0] * m1 + y[l+32] * m2;
+        }
+
+        const float dall = (float)x[i].d;
+        const float dmin = (float)x[i].dmin;
+
+        sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
+
+    }
+    sum[ith] = sumf;
+
+    //int mask1 = (ith%4 == 0);
+    //int mask2 = (ith%16 == 0);
+
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //if (ith == 0) {
+    //    for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    // This version is slightly faster than the commented out one below,
+    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+}
+
+kernel void kernel_mul_mat_q3_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const uint8_t m3 = 3;
+    const int8_t  m4 = 4;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;        // expecting 16
+    const int ip  = tid/8;          // 0 or 1
+    const int il  = tid/2 - 4*ip;   // 0...3
+    const int ir  = tid%2;
+    const int n   = 8;
+    const int l0  = n*ir;
+
+    const uint8_t m = 1 << (4*ip + il);
+
+    const int shift = 2*il;
+
+    const uint16_t s_shift1 = 4*ip;
+    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
+    const int ik = 4 + (il%2);
+
+    const int q_offset = 32*ip + l0;
+    const int y_offset = 128*ip + 32*il + l0;
+
+    //float sumf = 0;
+    float sumf1 = 0, sumf2 = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs + q_offset;
+        device const uint8_t * h = x[i].hmask + l0;
+        device const float   * y = yy + i * QK_K + y_offset;
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+
+        float s = 0;
+        for (int l = 0; l < n; ++l) {
+            s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
+        }
+        float d = d_all * s;
+        sumf1 += d * scales[0];
+        sumf2 += d;
+        //sumf += d_all * s * (scales[0] - 32);
+
+        s = 0;
+        for (int l = 0; l < n; ++l) {
+            s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
+        }
+        d = d_all * s;
+        sumf1 += d * scales[1];
+        sumf2 += d;
+        //sumf += d_all * s * (scales[1] - 32);
+
+    }
+
+    //sum[ith] = sumf;
+    sum[ith] = sumf1 - 32.f*sumf2;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+}
+
+kernel void kernel_mul_mat_q4_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;   // 0...16
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 4;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    sum[ith] = 0.0f;
+
+    uchar2 sc1, sc2, sc3, sc4;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q1 = (x + i)->qs + q_offset;
+        device const uint8_t * q2 = q1 + 64;
+        device const float   * y1 = yy + i*QK_K + y_offset;
+        device const float   * y2 = y1 + 128;
+
+        const float dall = (float)((x + i)->d);
+        const float dmin = (float)((x + i)->dmin);
+
+        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+
+            s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
+            s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
+        }
+        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+    }
+
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    // This version is slightly faster than the commented out one below,
+    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+    //// accumulate the sum from all threads in the threadgroup
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = nth/2; i > 0; i /= 2) {
+    //    if (ith < i) {
+    //        sum[ith] += sum[ith + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+
+    //if (ith == 0) {
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
+}
+
+kernel void kernel_mul_mat_q5_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;   // 0...16
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 4;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1 = 1u << (2*im);
+    const uint8_t hm2 = hm1 << 1;
+    const uint8_t hm3 = hm1 << 4;
+    const uint8_t hm4 = hm2 << 4;
+
+    uchar2 sc1, sc2, sc3, sc4;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q1 = (x + i)->qs + q_offset;
+        device const uint8_t * q2 = q1 + 64;
+        device const uint8_t * qh = (x + i)->qh + l0;
+        device const float   * y1 = yy + i*QK_K + y_offset;
+        device const float   * y2 = y1 + 128;
+
+        const float dall = (float)((x + i)->d);
+        const float dmin = (float)((x + i)->dmin);
+
+        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+
+            s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
+            s[1] += y1[l+32] * ((q1[l] >>  4) + (qh[l] & hm2 ? 16 : 0));
+            s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
+            s[3] += y2[l+32] * ((q2[l] >>  4) + (qh[l] & hm4 ? 16 : 0));
+            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
+        }
+        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+    }
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+}
+
+kernel void kernel_mul_mat_q6_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const uint8_t kmask1 = 0x03;
+    const uint8_t kmask2 = 0x0C;
+    const uint8_t kmask3 = 0x30;
+    const uint8_t kmask4 = 0xC0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
+    const int iqs  = 16 * tpitg.y;
+    const int ip   = iqs / 128;         // 0 or 1
+    const int il   = (iqs - 128*ip)/16; // 0...7
+    const int n    = 4;
+    const int l0   = n*il;
+    const int is   = 8*ip + l0/16;
+
+    const int y_offset = 128*ip + l0;
+    const int q_offset_l = 64*ip + l0;
+    const int q_offset_h = 32*ip + l0;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * ql = x[i].ql + q_offset_l;
+        device const uint8_t * qh = x[i].qh + q_offset_h;
+        device const int8_t  * sc = x[i].scales + is;
+
+        device const float * y = yy + i * QK_K + y_offset;
+
+        const float dall = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((ql[l+ 0]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((ql[l+32]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
+
+        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+    }
+
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+}