]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : reduce command encoding overhead (llama/9698)
authorGeorgi Gerganov <redacted>
Wed, 2 Oct 2024 12:12:16 +0000 (15:12 +0300)
committerGeorgi Gerganov <redacted>
Thu, 3 Oct 2024 09:22:17 +0000 (12:22 +0300)
ggml/include/ggml-metal.h
ggml/src/ggml-metal.m

index d483cf1ac40c6e91590599cb1196e68ad56a2615..4d416532449e08cc8b980f560b3c55483643775c 100644 (file)
@@ -25,9 +25,6 @@
 #include <stddef.h>
 #include <stdbool.h>
 
-// max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 64
-
 struct ggml_tensor;
 struct ggml_cgraph;
 
@@ -48,8 +45,6 @@ GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
 
 GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
 
-GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
-
 GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
 
 GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
index ef3b7f0e824a913e3e934d93f31886155bc80357..c1e3a66d931c2f1fde0b0dce84f2884371b14f15 100644 (file)
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+// max memory buffers that can be mapped to the device
+#define GGML_METAL_MAX_BUFFERS 64
+
+// max number of MTLCommandBuffer used to submit a graph for processing
+#define GGML_METAL_MAX_COMMAND_BUFFERS 8
+
 #ifdef GGML_METAL_NDEBUG
 #define GGML_METAL_LOG(...)
 #define GGML_METAL_LOG_INFO(...)
@@ -221,11 +227,11 @@ enum ggml_metal_kernel_type {
 };
 
 struct ggml_backend_metal_context {
-    int n_cb;
-
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
 
+    MTLComputePassDescriptor * edesc;
+
     dispatch_queue_t d_queue;
 
     struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
@@ -233,7 +239,27 @@ struct ggml_backend_metal_context {
     bool support_simdgroup_reduction;
     bool support_simdgroup_mm;
 
-    bool should_capture_next_compute;
+    // capture state
+    bool capture_next_compute;
+    bool capture_started;
+
+    id<MTLCaptureScope> capture_scope;
+
+    // command buffer state
+    int n_cb;           // number of extra threads used to submit the command buffers
+    int n_nodes_0;      // number of nodes submitted by the main thread
+    int n_nodes_1;      // remaining number of nodes submitted by the n_cb threads
+    int n_nodes_per_cb;
+
+    struct ggml_cgraph * gf;
+
+    // the callback given to the thread pool
+    // TODO: ideally, this should be created once, utilizing the command buffer state above
+    //       for some reason, doing it like this leads to a crash
+    void (^encode_async)(size_t ith);
+
+    // n_cb command buffers + 1 used by the main thread
+    id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
 
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
@@ -303,7 +329,7 @@ static void * ggml_metal_host_malloc(size_t n) {
     return data;
 }
 
-static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
+static struct ggml_backend_metal_context * ggml_metal_init(void) {
     GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
 #if TARGET_OS_OSX && !GGML_METAL_NDEBUG
@@ -322,8 +348,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
     // Configure context
     struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
     ctx->device = device;
-    ctx->n_cb   = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
     ctx->queue  = [ctx->device newCommandQueue];
+    ctx->edesc  = MTLComputePassDescriptor.computePassDescriptor;
+    ctx->edesc.dispatchType = MTLDispatchTypeSerial;
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
     id<MTLLibrary> metal_library;
@@ -455,7 +482,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n",       __func__, ctx->support_simdgroup_mm ? "true" : "false");
     GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
 
-    ctx->should_capture_next_compute = false;
+    ctx->capture_next_compute = false;
+    ctx->capture_started = false;
+    ctx->capture_scope = nil;
+
+    ctx->gf = nil;
+    ctx->encode_async = nil;
+    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
+        ctx->command_buffers[i] = nil;
+    }
 
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
@@ -686,6 +721,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
     }
 
     [metal_library release];
+
     return ctx;
 }
 
@@ -874,874 +910,820 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
     }
 }
 
-static enum ggml_status ggml_metal_graph_compute(
-        struct ggml_backend_metal_context * ctx,
-               struct ggml_cgraph * gf) {
+static void ggml_metal_encode_node(
+     struct ggml_backend_metal_context * ctx,
+                                   int   idx,
+          id<MTLComputeCommandEncoder>   encoder) {
+    struct ggml_cgraph * gf = ctx->gf;
 
-    @autoreleasepool {
-    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
-    edesc.dispatchType = MTLDispatchTypeSerial;
+    struct ggml_tensor * node = ggml_graph_node(gf, idx);
 
-    // create multiple command buffers and enqueue them
-    // then, we encode the graph into the command buffers in parallel
+    //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
 
-    const int n_nodes = gf->n_nodes;
-    const int n_cb = ctx->n_cb;
-    const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
+    struct ggml_tensor * src0 = node->src[0];
+    struct ggml_tensor * src1 = node->src[1];
+    struct ggml_tensor * src2 = node->src[2];
+    struct ggml_tensor * dst  = node;
 
-    const bool should_capture = ctx->should_capture_next_compute;
-    if (should_capture) {
-        ctx->should_capture_next_compute = false;
+    if (ggml_is_empty(dst)) {
+        return;
+    }
 
-        MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
-        descriptor.captureObject = ctx->queue;
+    switch (dst->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_PERMUTE:
+            {
+                // noop -> next node
+            } return;
+        default:
+            {
+            } break;
+    }
 
-        NSError * error = nil;
-        if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
-            GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
-            GGML_ABORT("capture failed");
-        }
+    if (!ggml_metal_supports_op(ctx, dst)) {
+        GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+        GGML_ABORT("unsupported op");
     }
 
-    id<MTLCommandBuffer> command_buffer_builder[n_cb];
-    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        id<MTLCommandBuffer> command_buffer  = [ctx->queue commandBufferWithUnretainedReferences];
-        command_buffer_builder[cb_idx] = command_buffer;
+    const int64_t  ne00 = src0 ? src0->ne[0] : 0;
+    const int64_t  ne01 = src0 ? src0->ne[1] : 0;
+    const int64_t  ne02 = src0 ? src0->ne[2] : 0;
+    const int64_t  ne03 = src0 ? src0->ne[3] : 0;
+
+    const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+    const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+    const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+    const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+
+    const int64_t  ne10 = src1 ? src1->ne[0] : 0;
+    const int64_t  ne11 = src1 ? src1->ne[1] : 0;
+    const int64_t  ne12 = src1 ? src1->ne[2] : 0;
+    const int64_t  ne13 = src1 ? src1->ne[3] : 0;
+
+    const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+    const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+    const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+    const uint64_t nb13 = src1 ? src1->nb[3] : 0;
+
+    const int64_t  ne20 = src2 ? src2->ne[0] : 0;
+    const int64_t  ne21 = src2 ? src2->ne[1] : 0;
+    const int64_t  ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
+    const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+    const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+    const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+    const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+
+    const int64_t  ne0  =  dst ?  dst->ne[0] : 0;
+    const int64_t  ne1  =  dst ?  dst->ne[1] : 0;
+    const int64_t  ne2  =  dst ?  dst->ne[2] : 0;
+    const int64_t  ne3  =  dst ?  dst->ne[3] : 0;
+
+    const uint64_t nb0  =  dst ?  dst->nb[0] : 0;
+    const uint64_t nb1  =  dst ?  dst->nb[1] : 0;
+    const uint64_t nb2  =  dst ?  dst->nb[2] : 0;
+    const uint64_t nb3  =  dst ?  dst->nb[3] : 0;
+
+    const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+    const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+    const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
+
+    size_t offs_src0 = 0;
+    size_t offs_src1 = 0;
+    size_t offs_src2 = 0;
+    size_t offs_dst  = 0;
+
+    id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
+    id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
+    id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
+    id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
+
+    //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+    //if (src0) {
+    //    GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+    //            ggml_is_contiguous(src0), src0->name);
+    //}
+    //if (src1) {
+    //    GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+    //            ggml_is_contiguous(src1), src1->name);
+    //}
+    //if (dst) {
+    //    GGML_METAL_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
+    //            dst->name);
+    //}
+
+    switch (dst->op) {
+        case GGML_OP_CONCAT:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
+
+                const int32_t dim = ((const int32_t *) dst->op_params)[0];
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
+                [encoder setBytes:&dim  length:sizeof(dim)  atIndex:27];
+
+                const int nth = MIN(1024, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_ADD:
+        case GGML_OP_SUB:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+            {
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
 
-        // always enqueue the first two command buffers
-        // enqueue all of the command buffers if we don't need to abort
-        if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer enqueue];
-        }
-    }
+                const size_t offs = 0;
 
-    const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
+                bool bcast_row = false;
 
-    dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
-        const int cb_idx = iter;
+                int64_t nb = ne00; // used by the "row" kernels
 
-        size_t offs_src0 = 0;
-        size_t offs_src1 = 0;
-        size_t offs_src2 = 0;
-        size_t offs_dst  = 0;
+                id<MTLComputePipelineState> pipeline = nil;
 
-        id<MTLCommandBuffer> command_buffer  = command_buffers[cb_idx];
-        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
 
-        const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
-        const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
+                    // src1 is a row
+                    GGML_ASSERT(ne11 == 1);
 
-        for (int i = node_start; i < node_end; ++i) {
-            if (i == -1) {
-                [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
-                continue;
-            }
+                    nb = ne00 / 4;
+                    switch (dst->op) {
+                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
 
-            //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+                    bcast_row = true;
+                } else {
+                    switch (dst->op) {
+                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+                }
 
-            struct ggml_tensor * src0 = gf->nodes[i]->src[0];
-            struct ggml_tensor * src1 = gf->nodes[i]->src[1];
-            struct ggml_tensor * src2 = gf->nodes[i]->src[2];
-            struct ggml_tensor * dst  = gf->nodes[i];
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
+                [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                [encoder setBytes:&nb   length:sizeof(nb)   atIndex:28];
+
+                if (bcast_row) {
+                    const int64_t n = ggml_nelements(dst)/4;
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } else {
+                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                }
+            } break;
+        case GGML_OP_REPEAT:
+            {
+                id<MTLComputePipelineState> pipeline;
+
+                switch (src0t) {
+                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+                    case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+                    case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+                    default: GGML_ABORT("fatal error");
+                }
 
-            if (ggml_is_empty(dst)) {
-                continue;
-            }
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_ACC:
+            {
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                GGML_ASSERT(dstt  == GGML_TYPE_F32);
+
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
+
+                const size_t pnb1 = ((const int32_t *) dst->op_params)[0];
+                const size_t pnb2 = ((const int32_t *) dst->op_params)[1];
+                const size_t pnb3 = ((const int32_t *) dst->op_params)[2];
+                const size_t offs = ((const int32_t *) dst->op_params)[3];
+
+                const bool inplace = (bool) ((const int32_t *) dst->op_params)[4];
+
+                if (!inplace) {
+                    // run a separete kernel to cpy src->dst
+                    // not sure how to avoid this
+                    // TODO: make a simpler cpy_bytes kernel
+
+                    const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                    [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                    [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                    [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                    [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                    [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+
+                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                }
 
-            switch (dst->op) {
-                case GGML_OP_NONE:
-                case GGML_OP_RESHAPE:
-                case GGML_OP_VIEW:
-                case GGML_OP_TRANSPOSE:
-                case GGML_OP_PERMUTE:
-                    {
-                        // noop -> next node
-                    } continue;
-                default:
-                    {
-                    } break;
-            }
+                const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+                [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+                [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+                [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+                [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+                [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_SCALE:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
 
-            if (!ggml_metal_supports_op(ctx, dst)) {
-                GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
-                GGML_ABORT("unsupported op");
-            }
+                float scale;
+                memcpy(&scale, dst->op_params, sizeof(scale));
 
-            if (should_capture) {
-                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
-            }
+                int64_t n = ggml_nelements(dst);
 
-            const int64_t  ne00 = src0 ? src0->ne[0] : 0;
-            const int64_t  ne01 = src0 ? src0->ne[1] : 0;
-            const int64_t  ne02 = src0 ? src0->ne[2] : 0;
-            const int64_t  ne03 = src0 ? src0->ne[3] : 0;
-
-            const uint64_t nb00 = src0 ? src0->nb[0] : 0;
-            const uint64_t nb01 = src0 ? src0->nb[1] : 0;
-            const uint64_t nb02 = src0 ? src0->nb[2] : 0;
-            const uint64_t nb03 = src0 ? src0->nb[3] : 0;
-
-            const int64_t  ne10 = src1 ? src1->ne[0] : 0;
-            const int64_t  ne11 = src1 ? src1->ne[1] : 0;
-            const int64_t  ne12 = src1 ? src1->ne[2] : 0;
-            const int64_t  ne13 = src1 ? src1->ne[3] : 0;
-
-            const uint64_t nb10 = src1 ? src1->nb[0] : 0;
-            const uint64_t nb11 = src1 ? src1->nb[1] : 0;
-            const uint64_t nb12 = src1 ? src1->nb[2] : 0;
-            const uint64_t nb13 = src1 ? src1->nb[3] : 0;
-
-            const int64_t  ne20 = src2 ? src2->ne[0] : 0;
-            const int64_t  ne21 = src2 ? src2->ne[1] : 0;
-            const int64_t  ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
-            const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
-
-            const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
-            const uint64_t nb21 = src2 ? src2->nb[1] : 0;
-            const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-            const uint64_t nb23 = src2 ? src2->nb[3] : 0;
-
-            const int64_t  ne0  =  dst ?  dst->ne[0] : 0;
-            const int64_t  ne1  =  dst ?  dst->ne[1] : 0;
-            const int64_t  ne2  =  dst ?  dst->ne[2] : 0;
-            const int64_t  ne3  =  dst ?  dst->ne[3] : 0;
-
-            const uint64_t nb0  =  dst ?  dst->nb[0] : 0;
-            const uint64_t nb1  =  dst ?  dst->nb[1] : 0;
-            const uint64_t nb2  =  dst ?  dst->nb[2] : 0;
-            const uint64_t nb3  =  dst ?  dst->nb[3] : 0;
-
-            const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
-            const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
-            const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
-
-            id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
-            id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
-            id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
-            id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
-
-            //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-            //if (src0) {
-            //    GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
-            //            ggml_is_contiguous(src0), src0->name);
-            //}
-            //if (src1) {
-            //    GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
-            //            ggml_is_contiguous(src1), src1->name);
-            //}
-            //if (dst) {
-            //    GGML_METAL_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
-            //            dst->name);
-            //}
-
-            switch (dst->op) {
-                case GGML_OP_CONCAT:
-                    {
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
-
-                        const int32_t dim = ((int32_t *) dst->op_params)[0];
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
-                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
-                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                        [encoder setBytes:&dim  length:sizeof(dim)  atIndex:27];
-
-                        const int nth = MIN(1024, ne0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_ADD:
-                case GGML_OP_SUB:
-                case GGML_OP_MUL:
-                case GGML_OP_DIV:
-                    {
-                        GGML_ASSERT(src0t == GGML_TYPE_F32);
-                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-
-                        const size_t offs = 0;
-
-                        bool bcast_row = false;
-
-                        int64_t nb = ne00; // used by the "row" kernels
-
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                            GGML_ASSERT(ggml_is_contiguous(src0));
-
-                            // src1 is a row
-                            GGML_ASSERT(ne11 == 1);
-
-                            nb = ne00 / 4;
-                            switch (dst->op) {
-                                case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
-                                case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
-                                case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
-                                case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
-                                default: GGML_ABORT("fatal error");
-                            }
+                id<MTLComputePipelineState> pipeline = nil;
 
-                            bcast_row = true;
-                        } else {
-                            switch (dst->op) {
-                                case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
-                                case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
-                                case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
-                                case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
-                                default: GGML_ABORT("fatal error");
-                            }
-                        }
+                if (n % 4 == 0) {
+                    n /= 4;
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+                }
 
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
-                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
-                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                        [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
-                        [encoder setBytes:&nb   length:sizeof(nb)   atIndex:28];
-
-                        if (bcast_row) {
-                            const int64_t n = ggml_nelements(dst)/4;
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } else {
-                            const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
+                [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                        }
-                    } break;
-                case GGML_OP_REPEAT:
-                    {
-                        id<MTLComputePipelineState> pipeline;
-
-                        switch (src0t) {
-                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
-                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
-                            case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
-                            case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
-                            default: GGML_ABORT("fatal error");
-                        }
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_CLAMP:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
 
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
-
-                        const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_ACC:
-                    {
-                        GGML_ASSERT(src0t == GGML_TYPE_F32);
-                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-                        GGML_ASSERT(dstt  == GGML_TYPE_F32);
-
-                        GGML_ASSERT(ggml_is_contiguous(src0));
-                        GGML_ASSERT(ggml_is_contiguous(src1));
-
-                        const size_t pnb1 = ((int32_t *) dst->op_params)[0];
-                        const size_t pnb2 = ((int32_t *) dst->op_params)[1];
-                        const size_t pnb3 = ((int32_t *) dst->op_params)[2];
-                        const size_t offs = ((int32_t *) dst->op_params)[3];
-
-                        const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
-
-                        if (!inplace) {
-                            // run a separete kernel to cpy src->dst
-                            // not sure how to avoid this
-                            // TODO: make a simpler cpy_bytes kernel
-
-                            const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
+                float min;
+                float max;
+                memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float));
+                memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
 
-                            [encoder setComputePipelineState:pipeline];
-                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
-
-                            const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                        }
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
+                [encoder setBytes:&min length:sizeof(min) atIndex:2];
+                [encoder setBytes:&max length:sizeof(max) atIndex:3];
 
-                        const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                        [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
-                        [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
-                        [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
-                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                        [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
-                        [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
-                        [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
-                        [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
-
-                        const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_SCALE:
-                    {
-                        GGML_ASSERT(ggml_is_contiguous(src0));
-
-                        float scale;
-                        memcpy(&scale, dst->op_params, sizeof(scale));
-
-                        int64_t n = ggml_nelements(dst);
-
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        if (n % 4 == 0) {
-                            n /= 4;
-                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
-                        } else {
-                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
-                        }
+                const int64_t n = ggml_nelements(dst);
 
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
-                        [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_CLAMP:
-                    {
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
-
-                        float min;
-                        float max;
-                        memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
-                        memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
-                        [encoder setBytes:&min length:sizeof(min) atIndex:2];
-                        [encoder setBytes:&max length:sizeof(max) atIndex:3];
-
-                        const int64_t n = ggml_nelements(dst);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_UNARY:
-                    switch (ggml_get_unary_op(gf->nodes[i])) {
-                        // we are not taking into account the strides, so for now require contiguous tensors
-                        GGML_ASSERT(ggml_is_contiguous(src0));
-
-                        case GGML_UNARY_OP_TANH:
-                            {
-                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(node)) {
+                // we are not taking into account the strides, so for now require contiguous tensors
+                GGML_ASSERT(ggml_is_contiguous(src0));
 
-                                [encoder setComputePipelineState:pipeline];
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                case GGML_UNARY_OP_TANH:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
 
-                                const int64_t n = ggml_nelements(dst);
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                            } break;
-                        case GGML_UNARY_OP_RELU:
-                            {
-                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
+                    const int64_t n = ggml_nelements(dst);
 
-                                [encoder setComputePipelineState:pipeline];
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_RELU:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
 
-                                const int64_t n = ggml_nelements(dst);
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                            } break;
-                        case GGML_UNARY_OP_SIGMOID:
-                            {
-                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+                    const int64_t n = ggml_nelements(dst);
 
-                                [encoder setComputePipelineState:pipeline];
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_SIGMOID:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
 
-                                const int64_t n = ggml_nelements(dst);
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                            } break;
-                        case GGML_UNARY_OP_GELU:
-                            {
-                                int64_t n = ggml_nelements(dst);
+                    const int64_t n = ggml_nelements(dst);
 
-                                id<MTLComputePipelineState> pipeline = nil;
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_GELU:
+                {
+                    int64_t n = ggml_nelements(dst);
 
-                                if (n % 4 == 0) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
-                                    n /= 4;
-                                } else {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
-                                }
+                    id<MTLComputePipelineState> pipeline = nil;
 
-                                [encoder setComputePipelineState:pipeline];
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    if (n % 4 == 0) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
+                        n /= 4;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                    }
 
-                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                            } break;
-                        case GGML_UNARY_OP_GELU_QUICK:
-                            {
-                                int64_t n = ggml_nelements(dst);
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                id<MTLComputePipelineState> pipeline = nil;
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_GELU_QUICK:
+                {
+                    int64_t n = ggml_nelements(dst);
 
-                                if (n % 4 == 0) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
-                                    n /= 4;
-                                } else {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
-                                }
+                    id<MTLComputePipelineState> pipeline = nil;
 
-                                [encoder setComputePipelineState:pipeline];
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    if (n % 4 == 0) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
+                        n /= 4;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                    }
 
-                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                            } break;
-                        case GGML_UNARY_OP_SILU:
-                            {
-                                int64_t n = ggml_nelements(dst);
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                id<MTLComputePipelineState> pipeline = nil;
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_SILU:
+                {
+                    int64_t n = ggml_nelements(dst);
 
-                                if (n % 4 == 0) {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
-                                    n /= 4;
-                                } else {
-                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
-                                }
+                    id<MTLComputePipelineState> pipeline = nil;
 
-                                [encoder setComputePipelineState:pipeline];
-                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    if (n % 4 == 0) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
+                        n /= 4;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                    }
 
-                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                            } break;
-                        default:
-                            {
-                                GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                                GGML_ABORT("fatal error");
-                            }
-                    } break;
-                case GGML_OP_SQR:
-                    {
-                        GGML_ASSERT(ggml_is_contiguous(src0));
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
-
-                        const int64_t n = ggml_nelements(dst);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_SQRT:
-                    {
-                        GGML_ASSERT(ggml_is_contiguous(src0));
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
-
-                        const int64_t n = ggml_nelements(dst);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_SIN:
-                    {
-                        GGML_ASSERT(ggml_is_contiguous(src0));
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
-
-                        const int64_t n = ggml_nelements(dst);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_COS:
-                    {
-                        GGML_ASSERT(ggml_is_contiguous(src0));
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
-
-                        const int64_t n = ggml_nelements(dst);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_SUM_ROWS:
-                    {
-                        GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
-                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
-                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
-                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:19];
-                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:20];
-                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:21];
-                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:22];
-                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:23];
-                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:24];
-                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:25];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_SOFT_MAX:
-                    {
-                        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-
-                        int nth = 32; // SIMD width
-
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
-
-                        if (ne00%4 == 0) {
-                            while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
-                                nth *= 2;
-                            }
-                            if (use_f16) {
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
-                            } else {
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
-                            }
-                        } else {
-                            while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
-                                nth *= 2;
-                            }
-                            if (use_f16) {
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
-                            } else {
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
-                            }
-                        }
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                        float scale;
-                        float max_bias;
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                default:
+                {
+                    GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_OP_SQR:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
 
-                        memcpy(&scale,    ((int32_t *) dst->op_params) + 0, sizeof(scale));
-                        memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
 
-                        const int64_t nrows_x = ggml_nrows(src0);
-                        const int64_t nrows_y = src0->ne[1];
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
 
-                        const uint32_t n_head      = nrows_x/nrows_y;
-                        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+                const int64_t n = ggml_nelements(dst);
 
-                        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-                        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SQRT:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
 
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
-                        if (id_src1) {
-                            [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
-                        } else {
-                            [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
-                        }
-                        [encoder setBuffer:id_dst      offset:offs_dst            atIndex:2];
-                        [encoder setBytes:&ne00        length:sizeof(ne00)        atIndex:3];
-                        [encoder setBytes:&ne01        length:sizeof(ne01)        atIndex:4];
-                        [encoder setBytes:&ne02        length:sizeof(ne02)        atIndex:5];
-                        [encoder setBytes:&scale       length:sizeof(scale)       atIndex:6];
-                        [encoder setBytes:&max_bias    length:sizeof(max_bias)    atIndex:7];
-                        [encoder setBytes:&m0          length:sizeof(m0)          atIndex:8];
-                        [encoder setBytes:&m1          length:sizeof(m1)          atIndex:9];
-                        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
-                        [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_DIAG_MASK_INF:
-                    {
-                        const int n_past = ((int32_t *)(dst->op_params))[0];
-
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        if (ne00%8 == 0) {
-                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
-                        } else {
-                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
-                        }
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
 
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                        [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
-                        [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
-                        [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
 
-                        if (ne00%8 == 0) {
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        }
-                        else {
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        }
-                    } break;
-                case GGML_OP_SSM_CONV:
-                    {
-                        GGML_ASSERT(src0t == GGML_TYPE_F32);
-                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-
-                        GGML_ASSERT(ggml_is_contiguous(src0));
-                        GGML_ASSERT(ggml_is_contiguous(src1));
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
-                        [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                        [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
-                        [encoder setBytes:&ne01    length:sizeof(ne01) atIndex:4];
-                        [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:5];
-                        [encoder setBytes:&nb00    length:sizeof(nb00) atIndex:6];
-                        [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
-                        [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
-                        [encoder setBytes:&ne10    length:sizeof(ne10) atIndex:9];
-                        [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:10];
-                        [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
-                        [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
-                        [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                        [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
-                        [encoder setBytes:&ne2     length:sizeof(ne2)  atIndex:15];
-                        [encoder setBytes:&nb0     length:sizeof(nb0)  atIndex:16];
-                        [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:17];
-                        [encoder setBytes:&nb2     length:sizeof(nb2)  atIndex:18];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_SSM_SCAN:
-                    {
-                        struct ggml_tensor * src3 = gf->nodes[i]->src[3];
-                        struct ggml_tensor * src4 = gf->nodes[i]->src[4];
-                        struct ggml_tensor * src5 = gf->nodes[i]->src[5];
-
-                        GGML_ASSERT(src3);
-                        GGML_ASSERT(src4);
-                        GGML_ASSERT(src5);
-
-                        size_t offs_src3 = 0;
-                        size_t offs_src4 = 0;
-                        size_t offs_src5 = 0;
-
-                        id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
-                        id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
-                        id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
-
-                        const int64_t  ne30 = src3->ne[0]; GGML_UNUSED(ne30);
-                        const int64_t  ne31 = src3->ne[1]; GGML_UNUSED(ne31);
-
-                        const uint64_t nb30 = src3->nb[0];
-                        const uint64_t nb31 = src3->nb[1];
-
-                        const int64_t  ne40 = src4->ne[0]; GGML_UNUSED(ne40);
-                        const int64_t  ne41 = src4->ne[1]; GGML_UNUSED(ne41);
-                        const int64_t  ne42 = src4->ne[2]; GGML_UNUSED(ne42);
-
-                        const uint64_t nb40 = src4->nb[0];
-                        const uint64_t nb41 = src4->nb[1];
-                        const uint64_t nb42 = src4->nb[2];
-
-                        const int64_t  ne50 = src5->ne[0]; GGML_UNUSED(ne50);
-                        const int64_t  ne51 = src5->ne[1]; GGML_UNUSED(ne51);
-                        const int64_t  ne52 = src5->ne[2]; GGML_UNUSED(ne52);
-
-                        const uint64_t nb50 = src5->nb[0];
-                        const uint64_t nb51 = src5->nb[1];
-                        const uint64_t nb52 = src5->nb[2];
-
-                        const int64_t d_state      = ne00;
-                        const int64_t d_inner      = ne01;
-                        const int64_t n_seq_tokens = ne11;
-                        const int64_t n_seqs       = ne02;
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                        [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                        [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
-                        [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
-                        [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
-
-                        [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7];
-                        [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8];
-                        [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
-                        [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10];
-
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
-                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
-                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
-                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                        [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
-                        [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
-                        [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
-                        [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
-                        [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
-                        [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
-                        [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
-                        [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
-                        [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
-                        [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
-                        [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_MUL_MAT:
-                    {
-                        GGML_ASSERT(ne00 == ne10);
-
-                        GGML_ASSERT(ne12 % ne02 == 0);
-                        GGML_ASSERT(ne13 % ne03 == 0);
-
-                        const uint r2 = ne12/ne02;
-                        const uint r3 = ne13/ne03;
-
-                        // find the break-even point where the matrix-matrix kernel becomes more efficient compared
-                        // to the matrix-vector kernel
-                        int ne11_mm_min = 1;
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SIN:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_COS:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SUM_ROWS:
+            {
+                GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:19];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:20];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:21];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:22];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:23];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:24];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:25];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+
+                int nth = 32; // SIMD width
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+                if (ne00%4 == 0) {
+                    while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
+                        nth *= 2;
+                    }
+                    if (use_f16) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
+                    }
+                } else {
+                    while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+                        nth *= 2;
+                    }
+                    if (use_f16) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
+                    }
+                }
+
+                float scale;
+                float max_bias;
+
+                memcpy(&scale,    ((const int32_t *) dst->op_params) + 0, sizeof(scale));
+                memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+                const int64_t nrows_x = ggml_nrows(src0);
+                const int64_t nrows_y = src0->ne[1];
+
+                const uint32_t n_head      = nrows_x/nrows_y;
+                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+                const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                if (id_src1) {
+                    [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                } else {
+                    [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
+                }
+                [encoder setBuffer:id_dst      offset:offs_dst            atIndex:2];
+                [encoder setBytes:&ne00        length:sizeof(ne00)        atIndex:3];
+                [encoder setBytes:&ne01        length:sizeof(ne01)        atIndex:4];
+                [encoder setBytes:&ne02        length:sizeof(ne02)        atIndex:5];
+                [encoder setBytes:&scale       length:sizeof(scale)       atIndex:6];
+                [encoder setBytes:&max_bias    length:sizeof(max_bias)    atIndex:7];
+                [encoder setBytes:&m0          length:sizeof(m0)          atIndex:8];
+                [encoder setBytes:&m1          length:sizeof(m1)          atIndex:9];
+                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_DIAG_MASK_INF:
+            {
+                const int n_past = ((const int32_t *)(dst->op_params))[0];
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                if (ne00%8 == 0) {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
+                }
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
+
+                if (ne00%8 == 0) {
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                }
+                else {
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                }
+            } break;
+        case GGML_OP_SSM_CONV:
+            {
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
+                [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01    length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&nb00    length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&ne10    length:sizeof(ne10) atIndex:9];
+                [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:10];
+                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
+                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
+                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
+                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
+                [encoder setBytes:&ne2     length:sizeof(ne2)  atIndex:15];
+                [encoder setBytes:&nb0     length:sizeof(nb0)  atIndex:16];
+                [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:17];
+                [encoder setBytes:&nb2     length:sizeof(nb2)  atIndex:18];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                struct ggml_tensor * src3 = node->src[3];
+                struct ggml_tensor * src4 = node->src[4];
+                struct ggml_tensor * src5 = node->src[5];
+
+                GGML_ASSERT(src3);
+                GGML_ASSERT(src4);
+                GGML_ASSERT(src5);
+
+                size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
+                size_t offs_src5 = 0;
+
+                id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
+                id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+
+                const int64_t  ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+                const int64_t  ne31 = src3->ne[1]; GGML_UNUSED(ne31);
+
+                const uint64_t nb30 = src3->nb[0];
+                const uint64_t nb31 = src3->nb[1];
+
+                const int64_t  ne40 = src4->ne[0]; GGML_UNUSED(ne40);
+                const int64_t  ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+                const int64_t  ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+
+                const uint64_t nb40 = src4->nb[0];
+                const uint64_t nb41 = src4->nb[1];
+                const uint64_t nb42 = src4->nb[2];
+
+                const int64_t  ne50 = src5->ne[0]; GGML_UNUSED(ne50);
+                const int64_t  ne51 = src5->ne[1]; GGML_UNUSED(ne51);
+                const int64_t  ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+
+                const uint64_t nb50 = src5->nb[0];
+                const uint64_t nb51 = src5->nb[1];
+                const uint64_t nb52 = src5->nb[2];
+
+                const int64_t d_state      = ne00;
+                const int64_t d_inner      = ne01;
+                const int64_t n_seq_tokens = ne11;
+                const int64_t n_seqs       = ne02;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
+
+                [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7];
+                [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8];
+                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
+                [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10];
+
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
+                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
+                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
+                [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
+                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
+                [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
+                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
+                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
+                [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
+                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
+                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_MUL_MAT:
+            {
+                GGML_ASSERT(ne00 == ne10);
+
+                GGML_ASSERT(ne12 % ne02 == 0);
+                GGML_ASSERT(ne13 % ne03 == 0);
+
+                const uint r2 = ne12/ne02;
+                const uint r3 = ne13/ne03;
+
+                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                // to the matrix-vector kernel
+                int ne11_mm_min = 1;
 
 #if 0
-                        // the numbers below are measured on M2 Ultra for 7B and 13B models
-                        // these numbers do not translate to other devices or model sizes
-                        // TODO: need to find a better approach
+                // the numbers below are measured on M2 Ultra for 7B and 13B models
+                // these numbers do not translate to other devices or model sizes
+                // TODO: need to find a better approach
                         if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
                             switch (src0t) {
                                 case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
@@ -1763,11 +1745,11 @@ static enum ggml_status ggml_metal_graph_compute(
                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                         if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                            !ggml_is_transposed(src0) &&
-                            !ggml_is_transposed(src1) &&
-                            src1t == GGML_TYPE_F32 &&
-                            ne00 % 32 == 0 && ne00 >= 64 &&
-                            (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
+                                !ggml_is_transposed(src0) &&
+                                !ggml_is_transposed(src1) &&
+                                src1t == GGML_TYPE_F32 &&
+                                ne00 % 32 == 0 && ne00 >= 64 &&
+                                (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
                             //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
                             // some Metal matrix data types require aligned pointers
@@ -2001,8 +1983,8 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];
 
                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                                src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                                    src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
+                                    src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2036,1041 +2018,1157 @@ static enum ggml_status ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                         }
-                    } break;
-                case GGML_OP_MUL_MAT_ID:
-                    {
-                        const int n_as = src0->ne[2];
+            } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                const int n_as = src0->ne[2];
+
+                // src2 = ids
+                const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
+
+                GGML_ASSERT(src2t == GGML_TYPE_I32);
+
+                GGML_ASSERT(!ggml_is_transposed(src0));
+                GGML_ASSERT(!ggml_is_transposed(src1));
+
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                // to the matrix-vector kernel
+                // ne20 = n_used_experts
+                // ne21 = n_rows
+                const int dst_rows = ne20*ne21;
+                const int dst_rows_min = n_as;
+                const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
+
+                // max size of the rowids array in the kernel shared buffer
+                GGML_ASSERT(dst_rows <= dst_rows_max);
+
+                // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+                // !!!
+                // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+                //       indirect matrix multiplication
+                // !!!
+                if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                        ne00 % 32 == 0 && ne00 >= 64 &&
+                        dst_rows > dst_rows_min) {
+
+                    // some Metal matrix data types require aligned pointers
+                    // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+                        case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8  == 0); break;
+                        default: break;
+                    }
 
-                        // src2 = ids
-                        const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    switch (src0->type) {
+                        case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32    ].pipeline; break;
+                        case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32    ].pipeline; break;
+                        case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32   ].pipeline; break;
+                        case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32   ].pipeline; break;
+                        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
+                        case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
+                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
+                        case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
+                        case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
+                        default: GGML_ABORT("MUL_MAT_ID not implemented");
+                    }
 
-                        GGML_ASSERT(src2t == GGML_TYPE_I32);
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
+                    [encoder setBuffer:id_src2 offset:offs_src2    atIndex:3];
+                    [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
+                    [encoder setBytes:&ne21    length:sizeof(ne21) atIndex:5];
+                    [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
+                    [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:7];
+                    [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:8];
+                    [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:9];
+                    [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:10];
+                    [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:11];
+                    [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:12];
+                    [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:13];
+                    [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:14];
+                    [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:15];
+                    [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:16];
+                    [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:17];
+                    [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:18];
+                    [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:19];
+
+                    [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                } else {
+                    int nth0 = 32;
+                    int nth1 = 1;
+                    int nrows = 1;
+                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    // use custom matrix x vector kernel
+                    switch (src0t) {
+                        case GGML_TYPE_F32:
+                            {
+                                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_F16:
+                            {
+                                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                nth0 = 32;
+                                nth1 = 1;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q4_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q4_1:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_1:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q8_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q2_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q3_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q4_K:
+                            {
+                                nth0 = 4; //1;
+                                nth1 = 8; //32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q6_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_XXS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_XS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ3_XXS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ3_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ1_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ1_M:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ4_NL:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ4_XS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
+                            } break;
+                        default:
+                            {
+                                GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
+                                GGML_ABORT("not implemented");
+                            }
+                    };
 
-                        GGML_ASSERT(!ggml_is_transposed(src0));
-                        GGML_ASSERT(!ggml_is_transposed(src1));
+                    if (ggml_is_quantized(src0t)) {
+                        GGML_ASSERT(ne00 >= nth0*nth1);
+                    }
 
-                        GGML_ASSERT(src1t == GGML_TYPE_F32);
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                    [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
+                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
+                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
+                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:20];
+                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:21];
+                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:22];
+
+                    const int64_t _ne1 = 1;
+                    const int tgz = dst_rows;
+
+                    if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
+                            src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
+                            src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
+                        const int mem_size = 32*sizeof(float);
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q4_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q3_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q5_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q6_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    } else {
+                        const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                }
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (src0->type) {
+                    case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32    ].pipeline; break;
+                    case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16    ].pipeline; break;
+                    case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0   ].pipeline; break;
+                    case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1   ].pipeline; break;
+                    case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0   ].pipeline; break;
+                    case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1   ].pipeline; break;
+                    case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0   ].pipeline; break;
+                    case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K   ].pipeline; break;
+                    case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K   ].pipeline; break;
+                    case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K   ].pipeline; break;
+                    case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K   ].pipeline; break;
+                    case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K   ].pipeline; break;
+                    case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
+                    case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+                    case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
+                    case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S  ].pipeline; break;
+                    case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S  ].pipeline; break;
+                    case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
+                    case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M  ].pipeline; break;
+                    case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
+                    case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
+                    case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                }
 
-                        // find the break-even point where the matrix-matrix kernel becomes more efficient compared
-                        // to the matrix-vector kernel
-                        // ne20 = n_used_experts
-                        // ne21 = n_rows
-                        const int dst_rows = ne20*ne21;
-                        const int dst_rows_min = n_as;
-                        const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+                [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+                [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+                [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+                [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:9];
+                [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:10];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+            } break;
+        case GGML_OP_RMS_NORM:
+            {
+                GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
 
-                        // max size of the rowids array in the kernel shared buffer
-                        GGML_ASSERT(dst_rows <= dst_rows_max);
+                float eps;
+                memcpy(&eps, dst->op_params, sizeof(float));
 
-                        // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
-                        // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                        // !!!
-                        // TODO: for now, always use mat-vec kernels until we figure out how to improve the
-                        //       indirect matrix multiplication
-                        // !!!
-                        if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                            ne00 % 32 == 0 && ne00 >= 64 &&
-                            dst_rows > dst_rows_min) {
+                int nth = 32; // SIMD width
 
-                            // some Metal matrix data types require aligned pointers
-                            // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
-                            switch (src0->type) {
-                                case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
-                                case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8  == 0); break;
-                                default: break;
-                            }
+                while (nth < ne00/4 && nth < 1024) {
+                    nth *= 2;
+                }
 
-                            id<MTLComputePipelineState> pipeline = nil;
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
 
-                            switch (src0->type) {
-                                case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32    ].pipeline; break;
-                                case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32    ].pipeline; break;
-                                case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32   ].pipeline; break;
-                                case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32   ].pipeline; break;
-                                case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32   ].pipeline; break;
-                                case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32   ].pipeline; break;
-                                case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32   ].pipeline; break;
-                                case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32   ].pipeline; break;
-                                case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
-                                case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
-                                case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
-                                case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
-                                case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
-                                default: GGML_ABORT("MUL_MAT_ID not implemented");
-                            }
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                            [encoder setComputePipelineState:pipeline];
-                            [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                            [encoder setBuffer:id_src2 offset:offs_src2    atIndex:3];
-                            [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
-                            [encoder setBytes:&ne21    length:sizeof(ne21) atIndex:5];
-                            [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
-                            [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:7];
-                            [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:8];
-                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:9];
-                            [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:10];
-                            [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:11];
-                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:12];
-                            [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:13];
-                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:14];
-                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:15];
-                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:16];
-                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:17];
-                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:18];
-                            [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:19];
-
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
-
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
-                        } else {
-                            int nth0 = 32;
-                            int nth1 = 1;
-                            int nrows = 1;
-                            //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+                const int64_t nrows = ggml_nrows(src0);
 
-                            id<MTLComputePipelineState> pipeline = nil;
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_GROUP_NORM:
+            {
+                GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ggml_is_contiguous(src0));
 
-                            // use custom matrix x vector kernel
-                            switch (src0t) {
-                                case GGML_TYPE_F32:
-                                    {
-                                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_F16:
-                                    {
-                                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                        nth0 = 32;
-                                        nth1 = 1;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q4_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q4_1:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q5_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q5_1:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q8_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q2_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q3_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q4_K:
-                                    {
-                                        nth0 = 4; //1;
-                                        nth1 = 8; //32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q5_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q6_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ2_XXS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ2_XS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ3_XXS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ3_S:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ2_S:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ1_S:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ1_M:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ4_NL:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ4_XS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
-                                    } break;
-                                default:
-                                    {
-                                        GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
-                                        GGML_ABORT("not implemented");
-                                    }
-                            };
+                float eps;
+                memcpy(&eps, dst->op_params + 1, sizeof(float));
 
-                            if (ggml_is_quantized(src0t)) {
-                                GGML_ASSERT(ne00 >= nth0*nth1);
-                            }
+                const int32_t n_groups = ((const int32_t *) dst->op_params)[0];
 
-                            [encoder setComputePipelineState:pipeline];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
-                            [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
-                            [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
-                            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
-                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
-                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
-                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
-                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
-                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:20];
-                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:21];
-                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:22];
-
-                            const int64_t _ne1 = 1;
-                            const int tgz = dst_rows;
+                int nth = 32; // SIMD width
 
-                            if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                                src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                                const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                                const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                                const int mem_size = 32*sizeof(float);
-                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q4_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q3_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q5_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q6_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            } else {
-                                const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
-                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                        }
-                    } break;
-                case GGML_OP_GET_ROWS:
-                    {
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        switch (src0->type) {
-                            case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32    ].pipeline; break;
-                            case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16    ].pipeline; break;
-                            case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0   ].pipeline; break;
-                            case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1   ].pipeline; break;
-                            case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0   ].pipeline; break;
-                            case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1   ].pipeline; break;
-                            case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0   ].pipeline; break;
-                            case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K   ].pipeline; break;
-                            case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K   ].pipeline; break;
-                            case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K   ].pipeline; break;
-                            case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K   ].pipeline; break;
-                            case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K   ].pipeline; break;
-                            case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
-                            case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
-                            case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
-                            case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S  ].pipeline; break;
-                            case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S  ].pipeline; break;
-                            case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
-                            case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M  ].pipeline; break;
-                            case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
-                            case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
-                            case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
-                            default: GGML_ABORT("not implemented");
-                        }
+                //while (nth < ne00/4 && nth < 1024) {
+                //    nth *= 2;
+                //}
 
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
-                        [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
-                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
-                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                        [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
-                        [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
-                        [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
-                        [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
-                        [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:9];
-                        [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:10];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
-                    } break;
-                case GGML_OP_RMS_NORM:
-                    {
-                        GGML_ASSERT(ne00 % 4 == 0);
-                        GGML_ASSERT(ggml_is_contiguous_1(src0));
-
-                        float eps;
-                        memcpy(&eps, dst->op_params, sizeof(float));
-
-                        int nth = 32; // SIMD width
-
-                        while (nth < ne00/4 && nth < 1024) {
-                            nth *= 2;
-                        }
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
 
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                        [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                        [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
-                        [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-                        const int64_t nrows = ggml_nrows(src0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_GROUP_NORM:
-                    {
-                        GGML_ASSERT(ne00 % 4 == 0);
-                        GGML_ASSERT(ggml_is_contiguous(src0));
-
-                        float eps;
-                        memcpy(&eps, dst->op_params + 1, sizeof(float));
-
-                        const int32_t n_groups = ((int32_t *) dst->op_params)[0];
-
-                        int nth = 32; // SIMD width
-
-                        //while (nth < ne00/4 && nth < 1024) {
-                        //    nth *= 2;
-                        //}
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
-                        [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
-                        [encoder setBytes:&ne00     length:sizeof( int64_t) atIndex:2];
-                        [encoder setBytes:&ne01     length:sizeof( int64_t) atIndex:3];
-                        [encoder setBytes:&ne02     length:sizeof( int64_t) atIndex:4];
-                        [encoder setBytes:&nb00     length:sizeof(uint64_t) atIndex:5];
-                        [encoder setBytes:&nb01     length:sizeof(uint64_t) atIndex:6];
-                        [encoder setBytes:&nb02     length:sizeof(uint64_t) atIndex:7];
-                        [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
-                        [encoder setBytes:&eps      length:sizeof(   float) atIndex:9];
-                        [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_NORM:
-                    {
-                        GGML_ASSERT(ggml_is_contiguous_1(src0));
-
-                        float eps;
-                        memcpy(&eps, dst->op_params, sizeof(float));
-
-                        const int nth = MIN(256, ne00);
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                        [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                        [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
-                        [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
-
-                        const int64_t nrows = ggml_nrows(src0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_ROPE:
-                    {
-                        GGML_ASSERT(ne10 == ne02);
-
-                        const int nth = MIN(1024, ne00);
-
-                        const int n_past     = ((int32_t *) dst->op_params)[0];
-                        const int n_dims     = ((int32_t *) dst->op_params)[1];
-                        const int mode       = ((int32_t *) dst->op_params)[2];
-                        // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
-                        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
-                        float freq_base;
-                        float freq_scale;
-                        float ext_factor;
-                        float attn_factor;
-                        float beta_fast;
-                        float beta_slow;
-
-                        memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-                        memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-                        memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-                        memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-                        memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-                        memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-
-                        const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        if (!is_neox) {
-                            switch (src0->type) {
-                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
-                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
-                                default: GGML_ABORT("fatal error");
-                            };
-                        } else {
-                            switch (src0->type) {
-                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
-                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
-                                default: GGML_ABORT("fatal error");
-                            };
-                        }
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ne00     length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&ne01     length:sizeof( int64_t) atIndex:3];
+                [encoder setBytes:&ne02     length:sizeof( int64_t) atIndex:4];
+                [encoder setBytes:&nb00     length:sizeof(uint64_t) atIndex:5];
+                [encoder setBytes:&nb01     length:sizeof(uint64_t) atIndex:6];
+                [encoder setBytes:&nb02     length:sizeof(uint64_t) atIndex:7];
+                [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+                [encoder setBytes:&eps      length:sizeof(   float) atIndex:9];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
-                        [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
-                        if (id_src2 != nil) {
-                            [encoder setBuffer:id_src2 offset:offs_src2        atIndex:2];
-                        } else {
-                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:2];
-                        }
-                        [encoder setBuffer:id_dst      offset:offs_dst         atIndex:3];
-                        [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:4];
-                        [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:5];
-                        [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:6];
-                        [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:7];
-                        [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:8];
-                        [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:9];
-                        [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:10];
-                        [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:11];
-                        [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:12];
-                        [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:13];
-                        [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:14];
-                        [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:15];
-                        [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:16];
-                        [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:17];
-                        [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:18];
-                        [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:19];
-                        [encoder setBytes:&n_past      length:sizeof(     int) atIndex:20];
-                        [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:21];
-                        [encoder setBytes:&n_ctx_orig  length:sizeof(     int) atIndex:22];
-                        [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23];
-                        [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24];
-                        [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25];
-                        [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:26];
-                        [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:27];
-                        [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:28];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_IM2COL:
-                    {
-                        GGML_ASSERT(src0->type == GGML_TYPE_F16);
-                        GGML_ASSERT(src1->type == GGML_TYPE_F32);
-                        GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
-
-                        const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
-                        const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
-                        const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
-                        const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
-                        const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
-                        const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
-
-                        const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
-                        const int32_t N  = src1->ne[is_2D ? 3 : 2];
-                        const int32_t IC = src1->ne[is_2D ? 2 : 1];
-                        const int32_t IH = is_2D ? src1->ne[1] : 1;
-                        const int32_t IW =         src1->ne[0];
-
-                        const int32_t KH = is_2D ? src0->ne[1] : 1;
-                        const int32_t KW =         src0->ne[0];
-
-                        const int32_t OH = is_2D ? dst->ne[2] : 1;
-                        const int32_t OW =         dst->ne[1];
-
-                        const int32_t CHW = IC * KH * KW;
-
-                        const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
-                        const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
-
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        switch (dst->type) {
-                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
-                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
-                            default: GGML_ABORT("fatal error");
-                        };
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src1 offset:offs_src1        atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                        [encoder setBytes:&ofs0    length:sizeof( int32_t) atIndex:2];
-                        [encoder setBytes:&ofs1    length:sizeof( int32_t) atIndex:3];
-                        [encoder setBytes:&IW      length:sizeof( int32_t) atIndex:4];
-                        [encoder setBytes:&IH      length:sizeof( int32_t) atIndex:5];
-                        [encoder setBytes:&CHW     length:sizeof( int32_t) atIndex:6];
-                        [encoder setBytes:&s0      length:sizeof( int32_t) atIndex:7];
-                        [encoder setBytes:&s1      length:sizeof( int32_t) atIndex:8];
-                        [encoder setBytes:&p0      length:sizeof( int32_t) atIndex:9];
-                        [encoder setBytes:&p1      length:sizeof( int32_t) atIndex:10];
-                        [encoder setBytes:&d0      length:sizeof( int32_t) atIndex:11];
-                        [encoder setBytes:&d1      length:sizeof( int32_t) atIndex:12];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
-                    } break;
-                case GGML_OP_UPSCALE:
-                    {
-                        GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
-                        const float sf0 = (float)ne0/src0->ne[0];
-                        const float sf1 = (float)ne1/src0->ne[1];
-                        const float sf2 = (float)ne2/src0->ne[2];
-                        const float sf3 = (float)ne3/src0->ne[3];
-
-                        const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
-                        [encoder setBytes:&sf0  length:sizeof(sf0)  atIndex:18];
-                        [encoder setBytes:&sf1  length:sizeof(sf1)  atIndex:19];
-                        [encoder setBytes:&sf2  length:sizeof(sf2)  atIndex:20];
-                        [encoder setBytes:&sf3  length:sizeof(sf3)  atIndex:21];
-
-                        const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_PAD:
-                    {
-                        GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
-
-                        const int nth = MIN(1024, ne0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_ARANGE:
-                    {
-                        GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-                        float start;
-                        float step;
-
-                        memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
-                        memcpy(&step,  ((int32_t *) dst->op_params) + 2, sizeof(float));
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_dst  offset:offs_dst    atIndex:0];
-                        [encoder setBytes:&ne0   length:sizeof(ne0)   atIndex:1];
-                        [encoder setBytes:&start length:sizeof(start) atIndex:2];
-                        [encoder setBytes:&step  length:sizeof(step)  atIndex:3];
-
-                        const int nth = MIN(1024, ne0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_TIMESTEP_EMBEDDING:
-                    {
-                        GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
-                        const int dim        = dst->op_params[0];
-                        const int max_period = dst->op_params[1];
-
-                        const int half = dim / 2;
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                        [encoder setBytes:&nb1   length:sizeof(nb1) atIndex:2];
-                        [encoder setBytes:&dim   length:sizeof(dim) atIndex:3];
-                        [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
-
-                        const int nth = MIN(1024, half);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                case GGML_OP_ARGSORT:
-                    {
-                        GGML_ASSERT(src0->type == GGML_TYPE_F32);
-                        GGML_ASSERT( dst->type == GGML_TYPE_I32);
-
-                        const int nrows = ggml_nrows(src0);
-
-                        enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
-
-                        // bitonic sort requires the number of elements to be power of 2
-                        int64_t ne00_padded = 1;
-                        while (ne00_padded < ne00) {
-                            ne00_padded *= 2;
-                        }
+                [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_NORM:
+            {
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
 
-                        // Metal kernels require the buffer size to be multiple of 16 bytes
-                        // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
-                        const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+                float eps;
+                memcpy(&eps, dst->op_params, sizeof(float));
 
-                        id<MTLComputePipelineState> pipeline = nil;
+                const int nth = MIN(256, ne00);
 
-                        switch (order) {
-                            case GGML_SORT_ORDER_ASC:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline;  break;
-                            case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
-                            default: GGML_ABORT("fatal error");
-                        };
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
 
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
-                        [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
-                        [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:2];
-                        [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
+                [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
+
+                const int64_t nrows = ggml_nrows(src0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_ROPE:
+            {
+                GGML_ASSERT(ne10 == ne02);
+
+                const int nth = MIN(1024, ne00);
+
+                const int n_past     = ((const int32_t *) dst->op_params)[0];
+                const int n_dims     = ((const int32_t *) dst->op_params)[1];
+                const int mode       = ((const int32_t *) dst->op_params)[2];
+                // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+                const int n_ctx_orig = ((const int32_t *) dst->op_params)[4];
+
+                float freq_base;
+                float freq_scale;
+                float ext_factor;
+                float attn_factor;
+                float beta_fast;
+                float beta_slow;
+
+                memcpy(&freq_base,   (const int32_t *) dst->op_params +  5, sizeof(float));
+                memcpy(&freq_scale,  (const int32_t *) dst->op_params +  6, sizeof(float));
+                memcpy(&ext_factor,  (const int32_t *) dst->op_params +  7, sizeof(float));
+                memcpy(&attn_factor, (const int32_t *) dst->op_params +  8, sizeof(float));
+                memcpy(&beta_fast,   (const int32_t *) dst->op_params +  9, sizeof(float));
+                memcpy(&beta_slow,   (const int32_t *) dst->op_params + 10, sizeof(float));
+
+                const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                if (!is_neox) {
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
+                } else {
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
+                }
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
-                    } break;
-                case GGML_OP_LEAKY_RELU:
-                    {
-                        GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
+                if (id_src2 != nil) {
+                    [encoder setBuffer:id_src2 offset:offs_src2        atIndex:2];
+                } else {
+                    [encoder setBuffer:id_src0 offset:offs_src0        atIndex:2];
+                }
+                [encoder setBuffer:id_dst      offset:offs_dst         atIndex:3];
+                [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:4];
+                [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:5];
+                [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:6];
+                [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:7];
+                [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:8];
+                [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:9];
+                [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:10];
+                [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:11];
+                [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:12];
+                [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:13];
+                [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:14];
+                [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:15];
+                [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:16];
+                [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:17];
+                [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:18];
+                [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:19];
+                [encoder setBytes:&n_past      length:sizeof(     int) atIndex:20];
+                [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:21];
+                [encoder setBytes:&n_ctx_orig  length:sizeof(     int) atIndex:22];
+                [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23];
+                [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24];
+                [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25];
+                [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:26];
+                [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:27];
+                [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:28];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_IM2COL:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
 
-                        float slope;
-                        memcpy(&slope, dst->op_params, sizeof(float));
+                const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+                const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+                const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+                const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+                const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+                const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
 
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+                const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
 
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
-                        [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+                const int32_t N  = src1->ne[is_2D ? 3 : 2];
+                const int32_t IC = src1->ne[is_2D ? 2 : 1];
+                const int32_t IH = is_2D ? src1->ne[1] : 1;
+                const int32_t IW =         src1->ne[0];
 
-                        const int64_t n = ggml_nelements(dst);
+                const int32_t KH = is_2D ? src0->ne[1] : 1;
+                const int32_t KW =         src0->ne[0];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_FLASH_ATTN_EXT:
-                    {
-                        GGML_ASSERT(ne00 % 4  == 0);
-                        GGML_ASSERT(ne11 % 32 == 0);
+                const int32_t OH = is_2D ? dst->ne[2] : 1;
+                const int32_t OW =         dst->ne[1];
 
-                        GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                const int32_t CHW = IC * KH * KW;
 
-                        GGML_ASSERT(ggml_are_same_shape (src1, src2));
+                const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
 
-                        struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+                id<MTLComputePipelineState> pipeline = nil;
 
-                        size_t offs_src3 = 0;
+                switch (dst->type) {
+                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
+                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+                    default: GGML_ABORT("fatal error");
+                };
 
-                        id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src1 offset:offs_src1        atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ofs0    length:sizeof( int32_t) atIndex:2];
+                [encoder setBytes:&ofs1    length:sizeof( int32_t) atIndex:3];
+                [encoder setBytes:&IW      length:sizeof( int32_t) atIndex:4];
+                [encoder setBytes:&IH      length:sizeof( int32_t) atIndex:5];
+                [encoder setBytes:&CHW     length:sizeof( int32_t) atIndex:6];
+                [encoder setBytes:&s0      length:sizeof( int32_t) atIndex:7];
+                [encoder setBytes:&s1      length:sizeof( int32_t) atIndex:8];
+                [encoder setBytes:&p0      length:sizeof( int32_t) atIndex:9];
+                [encoder setBytes:&p1      length:sizeof( int32_t) atIndex:10];
+                [encoder setBytes:&d0      length:sizeof( int32_t) atIndex:11];
+                [encoder setBytes:&d1      length:sizeof( int32_t) atIndex:12];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+            } break;
+        case GGML_OP_UPSCALE:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                const float sf0 = (float)ne0/src0->ne[0];
+                const float sf1 = (float)ne1/src0->ne[1];
+                const float sf2 = (float)ne2/src0->ne[2];
+                const float sf3 = (float)ne3/src0->ne[3];
+
+                const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+                [encoder setBytes:&sf0  length:sizeof(sf0)  atIndex:18];
+                [encoder setBytes:&sf1  length:sizeof(sf1)  atIndex:19];
+                [encoder setBytes:&sf2  length:sizeof(sf2)  atIndex:20];
+                [encoder setBytes:&sf3  length:sizeof(sf3)  atIndex:21];
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_PAD:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                const int nth = MIN(1024, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_ARANGE:
+            {
+                GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-                        GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
-                        GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
-                                "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+                float start;
+                float step;
 
-                        const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
-                      //const int64_t  ne31 = src3 ? src3->ne[1] : 0;
-                        const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
-                        const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+                memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float));
+                memcpy(&step,  ((const int32_t *) dst->op_params) + 2, sizeof(float));
 
-                        const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
-                        const uint64_t nb31 = src3 ? src3->nb[1] : 0;
-                        const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
-                        const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
 
-                        const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_dst  offset:offs_dst    atIndex:0];
+                [encoder setBytes:&ne0   length:sizeof(ne0)   atIndex:1];
+                [encoder setBytes:&start length:sizeof(start) atIndex:2];
+                [encoder setBytes:&step  length:sizeof(step)  atIndex:3];
 
-                        float scale;
-                        float max_bias;
-                        float logit_softcap;
-                        memcpy(&scale,         ((int32_t *) dst->op_params) + 0, sizeof(scale));
-                        memcpy(&max_bias,      ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
-                        memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
+                const int nth = MIN(1024, ne0);
 
-                        if (logit_softcap != 0.0f) {
-                            scale /= logit_softcap;
-                        }
+                [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-                        const uint32_t n_head      = src0->ne[2];
-                        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+                const int dim        = dst->op_params[0];
+                const int max_period = dst->op_params[1];
 
-                        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-                        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+                const int half = dim / 2;
 
-                        id<MTLComputePipelineState> pipeline = nil;
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
 
-                        bool use_vec_kernel = false;
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&nb1   length:sizeof(nb1) atIndex:2];
+                [encoder setBytes:&dim   length:sizeof(dim) atIndex:3];
+                [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
 
-                        if (ne01 >= 4 || (ne00%128 != 0)) {
-                            switch (ne00) {
-                                case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
-                                case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
-                                case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
-                                case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
-                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                              //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
-                                default:
-                                          {
-                                              GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                              GGML_METAL_LOG_ERROR("add template specialization for this size\n");
-                                              GGML_ABORT("add template specialization for this size");
-                                          }
-                            }
-                        } else {
-                            use_vec_kernel = true;
+                const int nth = MIN(1024, half);
 
-                            switch (ne00) {
-                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                              //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
-                                default:
-                                          {
-                                              GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                              GGML_METAL_LOG_ERROR("add template specialization for this size\n");
-                                              GGML_ABORT("add template specialization for this size");
-                                          }
-                            }
-                        }
+                [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_ARGSORT:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                GGML_ASSERT( dst->type == GGML_TYPE_I32);
 
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0     offset:offs_src0           atIndex:0];
-                        [encoder setBuffer:id_src1     offset:offs_src1           atIndex:1];
-                        [encoder setBuffer:id_src2     offset:offs_src2           atIndex:2];
-                        if (id_src3) {
-                            [encoder setBuffer:id_src3     offset:offs_src3           atIndex:3];
-                        } else {
-                            [encoder setBuffer:id_src0     offset:offs_src0           atIndex:3];
-                        }
-                        [encoder setBuffer:id_dst        offset:offs_dst              atIndex:4];
-                        [encoder setBytes:&ne01          length:sizeof( int64_t)      atIndex:5];
-                        [encoder setBytes:&ne02          length:sizeof( int64_t)      atIndex:6];
-                        [encoder setBytes:&ne03          length:sizeof( int64_t)      atIndex:7];
-                        [encoder setBytes:&nb01          length:sizeof(uint64_t)      atIndex:8];
-                        [encoder setBytes:&nb02          length:sizeof(uint64_t)      atIndex:9];
-                        [encoder setBytes:&nb03          length:sizeof(uint64_t)      atIndex:10];
-                        [encoder setBytes:&ne11          length:sizeof( int64_t)      atIndex:11];
-                        [encoder setBytes:&ne12          length:sizeof( int64_t)      atIndex:12];
-                        [encoder setBytes:&ne13          length:sizeof( int64_t)      atIndex:13];
-                        [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
-                        [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
-                        [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
-                        [encoder setBytes:&nb21          length:sizeof(uint64_t)      atIndex:17];
-                        [encoder setBytes:&nb22          length:sizeof(uint64_t)      atIndex:18];
-                        [encoder setBytes:&nb23          length:sizeof(uint64_t)      atIndex:19];
-                        [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:20];
-                        [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:21];
-                        [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:22];
-                        [encoder setBytes:&scale         length:sizeof(   float)      atIndex:23];
-                        [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:24];
-                        [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25];
-                        [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26];
-                        [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27];
-                        [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
-
-                        if (!use_vec_kernel) {
-                            // half8x8 kernel
-                            const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
-                            const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-
-                            GGML_ASSERT(nqptg <= 32);
-                            GGML_ASSERT(nqptg  % 8  == 0);
-                            GGML_ASSERT(ncpsg  % 32 == 0);
-
-                            int64_t nsgmax = 2;
-
-                            while (true) {
-                                const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
-                                if (smem > ctx->device.maxThreadgroupMemoryLength) {
-                                    break;
-                                }
-                                nsgmax *= 2;
-                            }
-                            nsgmax /= 2;
+                const int nrows = ggml_nrows(src0);
 
-                            // simdgroups per threadgroup (a.k.a. warps)
-                            const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+                enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
-                            const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+                // bitonic sort requires the number of elements to be power of 2
+                int64_t ne00_padded = 1;
+                while (ne00_padded < ne00) {
+                    ne00_padded *= 2;
+                }
 
-                            //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-                            GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+                // Metal kernels require the buffer size to be multiple of 16 bytes
+                // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
 
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+                id<MTLComputePipelineState> pipeline = nil;
 
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
-                        } else {
-                            // half1x4 kernel
-                            const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
-                            const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+                switch (order) {
+                    case GGML_SORT_ORDER_ASC:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline;  break;
+                    case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
+                    default: GGML_ABORT("fatal error");
+                };
 
-                            GGML_ASSERT(nqptg <= 32);
-                            GGML_ASSERT(nqptg  % 1  == 0);
-                            GGML_ASSERT(ncpsg  % 32 == 0);
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
 
-                            // simdgroups per threadgroup (a.k.a. warps)
-                            const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+                [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
+            } break;
+        case GGML_OP_LEAKY_RELU:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-                            int64_t nsg = 1;
-                            while (nsg <= nsgt) {
-                                nsg *= 2;
-                            }
-                            nsg /= 2;
+                float slope;
+                memcpy(&slope, dst->op_params, sizeof(float));
 
-                            const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
 
-                            //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-                            GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
+                [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
 
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
-                        }
-                    } break;
-                case GGML_OP_DUP:
-                case GGML_OP_CPY:
-                case GGML_OP_CONT:
-                    {
-                        GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
-                        int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        switch (src0t) {
-                            case GGML_TYPE_F32:
-                                {
-                                    GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
-
-                                    switch (dstt) {
-                                        case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
-                                        case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
-                                        case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
-                                        case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
-                                        case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
-                                        case GGML_TYPE_Q5_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
-                                        case GGML_TYPE_Q5_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
-                                        case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
-                                        default: GGML_ABORT("not implemented");
-                                    };
-                                } break;
-                            case GGML_TYPE_F16:
-                                {
-                                    switch (dstt) {
-                                        case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
-                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
-                                        default: GGML_ABORT("not implemented");
-                                    };
-                                } break;
-                            default: GGML_ABORT("not implemented");
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                GGML_ASSERT(ne00 % 4  == 0);
+                GGML_ASSERT(ne11 % 32 == 0);
+
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                GGML_ASSERT(ggml_are_same_shape (src1, src2));
+
+                struct ggml_tensor * src3 = node->src[3];
+
+                size_t offs_src3 = 0;
+
+                id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+
+                GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
+                GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
+                        "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+
+                const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
+                //const int64_t  ne31 = src3 ? src3->ne[1] : 0;
+                const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
+                const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+
+                const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
+                const uint64_t nb31 = src3 ? src3->nb[1] : 0;
+                const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
+                const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+
+                const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+                float scale;
+                float max_bias;
+                float logit_softcap;
+                memcpy(&scale,         ((const int32_t *) dst->op_params) + 0, sizeof(scale));
+                memcpy(&max_bias,      ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
+                memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
+
+                if (logit_softcap != 0.0f) {
+                    scale /= logit_softcap;
+                }
+
+                const uint32_t n_head      = src0->ne[2];
+                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+                const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                bool use_vec_kernel = false;
+
+                if (ne01 >= 4 || (ne00%128 != 0)) {
+                    switch (ne00) {
+                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+                                  //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                        default:
+                                  {
+                                      GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                      GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                      GGML_ABORT("add template specialization for this size");
+                                  }
+                    }
+                } else {
+                    use_vec_kernel = true;
+
+                    switch (ne00) {
+                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                                  //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                        default:
+                                  {
+                                      GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                      GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                      GGML_ABORT("add template specialization for this size");
+                                  }
+                    }
+                }
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0     offset:offs_src0           atIndex:0];
+                [encoder setBuffer:id_src1     offset:offs_src1           atIndex:1];
+                [encoder setBuffer:id_src2     offset:offs_src2           atIndex:2];
+                if (id_src3) {
+                    [encoder setBuffer:id_src3     offset:offs_src3           atIndex:3];
+                } else {
+                    [encoder setBuffer:id_src0     offset:offs_src0           atIndex:3];
+                }
+                [encoder setBuffer:id_dst        offset:offs_dst              atIndex:4];
+                [encoder setBytes:&ne01          length:sizeof( int64_t)      atIndex:5];
+                [encoder setBytes:&ne02          length:sizeof( int64_t)      atIndex:6];
+                [encoder setBytes:&ne03          length:sizeof( int64_t)      atIndex:7];
+                [encoder setBytes:&nb01          length:sizeof(uint64_t)      atIndex:8];
+                [encoder setBytes:&nb02          length:sizeof(uint64_t)      atIndex:9];
+                [encoder setBytes:&nb03          length:sizeof(uint64_t)      atIndex:10];
+                [encoder setBytes:&ne11          length:sizeof( int64_t)      atIndex:11];
+                [encoder setBytes:&ne12          length:sizeof( int64_t)      atIndex:12];
+                [encoder setBytes:&ne13          length:sizeof( int64_t)      atIndex:13];
+                [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
+                [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
+                [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
+                [encoder setBytes:&nb21          length:sizeof(uint64_t)      atIndex:17];
+                [encoder setBytes:&nb22          length:sizeof(uint64_t)      atIndex:18];
+                [encoder setBytes:&nb23          length:sizeof(uint64_t)      atIndex:19];
+                [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:20];
+                [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:21];
+                [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:22];
+                [encoder setBytes:&scale         length:sizeof(   float)      atIndex:23];
+                [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:24];
+                [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25];
+                [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26];
+                [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27];
+                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+
+                if (!use_vec_kernel) {
+                    // half8x8 kernel
+                    const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
+                    const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                    GGML_ASSERT(nqptg <= 32);
+                    GGML_ASSERT(nqptg  % 8  == 0);
+                    GGML_ASSERT(ncpsg  % 32 == 0);
+
+                    int64_t nsgmax = 2;
+
+                    while (true) {
+                        const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+                        if (smem > ctx->device.maxThreadgroupMemoryLength) {
+                            break;
                         }
+                        nsgmax *= 2;
+                    }
+                    nsgmax /= 2;
 
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                        [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                        [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
-                        [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
-                        [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
-                        [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
-                        [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
-                        [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
-                        [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
-                        [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
-                        [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
-                        [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
-                        [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
-                        [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
-                        [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
-                        [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
-                        [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
-                default:
-                    {
-                        GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                        GGML_ABORT("fatal error");
+                    // simdgroups per threadgroup (a.k.a. warps)
+                    const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+
+                    const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+
+                    //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+                    GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+
+                    [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                } else {
+                    // half1x4 kernel
+                    const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
+                    const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                    GGML_ASSERT(nqptg <= 32);
+                    GGML_ASSERT(nqptg  % 1  == 0);
+                    GGML_ASSERT(ncpsg  % 32 == 0);
+
+                    // simdgroups per threadgroup (a.k.a. warps)
+                    const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+
+                    int64_t nsg = 1;
+                    while (nsg <= nsgt) {
+                        nsg *= 2;
                     }
+                    nsg /= 2;
+
+                    const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+
+                    //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+                    GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+                    [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                }
+            } break;
+        case GGML_OP_DUP:
+        case GGML_OP_CPY:
+        case GGML_OP_CONT:
+            {
+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (src0t) {
+                    case GGML_TYPE_F32:
+                        {
+                            GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
+                            switch (dstt) {
+                                case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+                                case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+                                case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+                                case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+                                case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+                                case GGML_TYPE_Q5_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+                                case GGML_TYPE_Q5_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+                                case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_F16:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+                                case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    default: GGML_ABORT("not implemented");
+                }
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+       default:
+            {
+                GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
+                GGML_ABORT("fatal error");
             }
+    }
+}
+
+static enum ggml_status ggml_metal_graph_compute(
+        struct ggml_backend_metal_context * ctx,
+                       struct ggml_cgraph * gf) {
+    // number of nodes encoded by the main thread (empirically determined)
+    const int n_main = 128;
+
+    // number of threads in addition to the main thread
+    const int n_cb = ctx->n_cb;
+
+    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
+    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
+    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
+    // each thread creates it's own command buffer and enqueues the ops in parallel
+    //
+    // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2
+
+    @autoreleasepool {
+        ctx->gf = gf;
 
-            if (should_capture) {
-                [encoder popDebugGroup];
+        ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
+        ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
+
+        ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
+
+        const bool should_capture = ctx->capture_next_compute;
+        if (should_capture) {
+            ctx->capture_next_compute = false;
+
+            if (!ctx->capture_started) {
+                // create capture scope
+                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
+
+                MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
+                descriptor.captureObject = ctx->capture_scope;
+                descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
+                descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
+
+                NSError * error = nil;
+                if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
+                    GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
+                } else {
+                    [ctx->capture_scope beginScope];
+                    ctx->capture_started = true;
+                }
             }
         }
 
-        [encoder endEncoding];
+        // TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes.
+        ctx->encode_async = ^(size_t iter) {
+            const int cb_idx = iter;
+            const int n_cb_l = ctx->n_cb;
 
-        if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer commit];
-        }
-    });
+            const int n_nodes_0 = ctx->n_nodes_0;
+            const int n_nodes_1 = ctx->n_nodes_1;
+
+            const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-    // Wait for completion and check status of each command buffer
-    // needed to detect if the device ran out-of-memory for example (#1881)
+            id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
+            id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: ctx->edesc];
 
-    for (int i = 0; i < n_cb; ++i) {
-        id<MTLCommandBuffer> command_buffer = command_buffers[i];
-        [command_buffer waitUntilCompleted];
+            int node_start = 0;
+            int node_end   = n_nodes_0;
 
-        MTLCommandBufferStatus status = [command_buffer status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_METAL_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+            if (cb_idx < n_cb_l) {
+                node_start = n_nodes_0 + (                                         (cb_idx + 0) * n_nodes_per_cb);
+                node_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
             }
 
-            return GGML_STATUS_FAILED;
-        }
+            for (int idx = node_start; idx < node_end; ++idx) {
+                if (should_capture) {
+                    [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
+                }
 
-        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
-        if (!next_buffer) {
-            continue;
+                ggml_metal_encode_node(ctx, idx, encoder);
+
+                if (should_capture) {
+                    [encoder popDebugGroup];
+                }
+            }
+
+            [encoder endEncoding];
+
+            if (cb_idx < 2 || ctx->abort_callback == NULL) {
+                [command_buffer commit];
+            }
+        };
+
+        // the main thread commits the first few commands immediately
+        // command_buffer[n_cb]
+        {
+            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->command_buffers[n_cb] = command_buffer;
+
+            [command_buffer enqueue];
+            ctx->encode_async(n_cb);
         }
 
-        bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
-        if (next_queued) {
-            continue;
+        // prepare the rest of the command buffers asynchronously
+        // command_buffer[0.. n_cb)
+        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->command_buffers[cb_idx] = command_buffer;
+
+            // always enqueue the first two command buffers
+            // enqueue all of the command buffers if we don't need to abort
+            if (cb_idx < 2 || ctx->abort_callback == NULL) {
+                [command_buffer enqueue];
+            }
         }
 
-        if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
-            GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i);
-            return GGML_STATUS_ABORTED;
+        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
+
+        // wait for completion and check status of each command buffer
+        // needed to detect if the device ran out-of-memory for example (#1881)
+        {
+            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
+            [command_buffer waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [command_buffer status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_METAL_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
+            }
         }
 
-        [next_buffer commit];
-    }
+        for (int i = 0; i < n_cb; ++i) {
+            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
+            [command_buffer waitUntilCompleted];
 
-    if (should_capture) {
-        [[MTLCaptureManager sharedCaptureManager] stopCapture];
-    }
+            MTLCommandBufferStatus status = [command_buffer status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_METAL_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
+            }
+
+            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
+            if (!next_buffer) {
+                continue;
+            }
+
+            const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+            if (next_queued) {
+                continue;
+            }
+
+            if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+                GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+                return GGML_STATUS_ABORTED;
+            }
+
+            [next_buffer commit];
+        }
 
+        if (!should_capture && ctx->capture_started) {
+            [ctx->capture_scope endScope];
+            [[MTLCaptureManager sharedCaptureManager] stopCapture];
+        }
     }
+
     return GGML_STATUS_SUCCESS;
 }
 
@@ -3405,6 +3503,25 @@ GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, g
     UNUSED(backend);
 }
 
+static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
+    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+
+    if (ctx->n_cb != n_cb) {
+        ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
+
+        if (ctx->n_cb > 2) {
+            GGML_METAL_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
+        }
+    }
+
+    // TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why?
+    //ctx->encode_async = ^(size_t iter) {
+    //    ...
+    //};
+}
+
 static struct ggml_backend_i ggml_backend_metal_i = {
     /* .get_name                = */ ggml_backend_metal_name,
     /* .free                    = */ ggml_backend_metal_free,
@@ -3439,35 +3556,29 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
 }
 
 ggml_backend_t ggml_backend_metal_init(void) {
-    struct ggml_backend_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+    struct ggml_backend_metal_context * ctx = ggml_metal_init();
     if (ctx == NULL) {
         GGML_METAL_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
         return NULL;
     }
 
-    ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
+    ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
 
-    *metal_backend = (struct ggml_backend) {
+    *backend = (struct ggml_backend) {
         /* .guid      = */ ggml_backend_metal_guid(),
         /* .interface = */ ggml_backend_metal_i,
         /* .context   = */ ctx,
     };
 
-    return metal_backend;
+    ggml_backend_metal_set_n_cb(backend, 1);
+
+    return backend;
 }
 
 bool ggml_backend_is_metal(ggml_backend_t backend) {
     return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
 }
 
-void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
-    GGML_ASSERT(ggml_backend_is_metal(backend));
-
-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
-}
-
 void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
     GGML_ASSERT(ggml_backend_is_metal(backend));
 
@@ -3489,7 +3600,7 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
     GGML_ASSERT(ggml_backend_is_metal(backend));
 
     struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
-    ctx->should_capture_next_compute = true;
+    ctx->capture_next_compute = true;
 }
 
 GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning