]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : parallel command buffer encoding (#1860)
authorGeorgi Gerganov <redacted>
Thu, 15 Jun 2023 17:29:48 +0000 (20:29 +0300)
committerGitHub <redacted>
Thu, 15 Jun 2023 17:29:48 +0000 (20:29 +0300)
* metal : parallel command buffer encoding

* metal : determine number of command buffers based on gf->n_threads

ggml-metal.h
ggml-metal.m

index a9441a9d46eaca46af849d453762a8d3453e38e3..033c4d86ab6c57589db232b38ede8c857a4e6c9e 100644 (file)
@@ -55,6 +55,7 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
 // same as ggml_graph_compute but uses Metal
+// creates gf->n_threads command buffers in parallel
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
 
 #ifdef __cplusplus
index 658c392e0d1bbb5826f6733cb50b2967b4849fd2..0e9b56aa33efa5445e66b0433316f71a2ac2ac84 100644 (file)
@@ -284,528 +284,551 @@ void ggml_metal_get_tensor(
 
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
-             struct ggml_cgraph * gf) {
+               struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
-    size_t offs_src0 = 0;
-    size_t offs_src1 = 0;
-    size_t offs_dst  = 0;
-
-    id<MTLCommandBuffer> command_buffer  = [ctx->queue commandBuffer];
-    id<MTLComputeCommandEncoder> encoder = nil;
-
-    for (int i = 0; i < gf->n_nodes; ++i) {
-        //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
-
-        struct ggml_tensor * src0 = gf->nodes[i]->src0;
-        struct ggml_tensor * src1 = gf->nodes[i]->src1;
-        struct ggml_tensor * dst  = gf->nodes[i];
-
-        const int64_t  ne00 = src0 ? src0->ne[0] : 0;
-        const int64_t  ne01 = src0 ? src0->ne[1] : 0;
-        const int64_t  ne02 = src0 ? src0->ne[2] : 0;
-        const int64_t  ne03 = src0 ? src0->ne[3] : 0;
-
-        const uint64_t nb00 = src0 ? src0->nb[0] : 0;
-        const uint64_t nb01 = src0 ? src0->nb[1] : 0;
-        const uint64_t nb02 = src0 ? src0->nb[2] : 0;
-        const uint64_t nb03 = src0 ? src0->nb[3] : 0;
-
-        const int64_t  ne10 = src1 ? src1->ne[0] : 0;
-        const int64_t  ne11 = src1 ? src1->ne[1] : 0;
-        const int64_t  ne12 = src1 ? src1->ne[2] : 0;
-        const int64_t  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
-
-        const uint64_t nb10 = src1 ? src1->nb[0] : 0;
-        const uint64_t nb11 = src1 ? src1->nb[1] : 0;
-        const uint64_t nb12 = src1 ? src1->nb[2] : 0;
-        const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
-
-        const int64_t  ne0  = dst ? dst->ne[0] : 0;
-        const int64_t  ne1  = dst ? dst->ne[1] : 0;
-        const int64_t  ne2  = dst ? dst->ne[2] : 0;
-        const int64_t  ne3  = dst ? dst->ne[3] : 0;
-
-        const uint64_t nb0  = dst ? dst->nb[0] : 0;
-        const uint64_t nb1  = dst ? dst->nb[1] : 0;
-        const uint64_t nb2  = dst ? dst->nb[2] : 0;
-        const uint64_t nb3  = dst ? dst->nb[3] : 0;
-
-        const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
-        const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
-        const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
-
-        id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
-        id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
-        id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(ctx, dst,  &offs_dst)  : nil;
-
-        //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-        //if (src0) {
-        //    metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
-        //            ggml_is_contiguous(src0), src0->name);
-        //}
-        //if (src1) {
-        //    metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
-        //            ggml_is_contiguous(src1), src1->name);
-        //}
-        //if (dst) {
-        //    metal_printf("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
-        //            dst->name);
-        //}
-
-        switch (dst->op) {
-            case GGML_OP_RESHAPE:
-            case GGML_OP_VIEW:
-            case GGML_OP_TRANSPOSE:
-            case GGML_OP_PERMUTE:
-                {
-                    // noop
-                } break;
-            case GGML_OP_ADD:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    [encoder setComputePipelineState:ctx->pipeline_add];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-            case GGML_OP_MUL:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    if (ggml_nelements(src1) == ne10) {
-                        // src1 is a row
-                        [encoder setComputePipelineState:ctx->pipeline_mul_row];
-                    } else {
-                        [encoder setComputePipelineState:ctx->pipeline_mul];
-                    }
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-            case GGML_OP_SCALE:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    const float scale = *(const float *) src1->data;
-
-                    [encoder setComputePipelineState:ctx->pipeline_scale];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                    [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-            case GGML_OP_SILU:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    [encoder setComputePipelineState:ctx->pipeline_silu];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-            case GGML_OP_RELU:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    [encoder setComputePipelineState:ctx->pipeline_relu];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-            case GGML_OP_GELU:
-            {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    [encoder setComputePipelineState:ctx->pipeline_gelu];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-            case GGML_OP_SOFT_MAX:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    const int nth = 32;
-
-                    [encoder setComputePipelineState:ctx->pipeline_soft_max];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                    [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                } break;
-            case GGML_OP_DIAG_MASK_INF:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    const int n_past = ((int32_t *)(src1->data))[0];
-
-                    [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                    [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
-                    [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
-                    [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-            case GGML_OP_MUL_MAT:
-                {
-                    // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
-
-                    GGML_ASSERT(ne00 == ne10);
-                    GGML_ASSERT(ne02 == ne12);
-
-                    if (ggml_is_contiguous(src0) &&
-                        ggml_is_contiguous(src1) &&
-                        (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
-
-                        if (encoder != nil) {
-                            [encoder endEncoding];
-                            encoder = nil;
-                        }
-
-                        MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-                        MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-
-                        // for F32 x F32 we use MPS
-                        MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
-
-                        MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
-
-                        MPSMatrixDescriptor * desc  = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
-
-                        MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
-                            initWithDevice:ctx->device transposeLeft:false transposeRight:true
-                                resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
-
-                        // we need to do ne02 multiplications
-                        // TODO: is there a way to do this in parallel - currently very slow ..
-                        // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
-                        for (int64_t i02 = 0; i02 < ne02; ++i02) {
-                            size_t offs_src0_cur = offs_src0 + i02*nb02;
-                            size_t offs_src1_cur = offs_src1 + i02*nb12;
-                            size_t offs_dst_cur  = offs_dst  + i02*nb2;
-
-                            MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
-                            MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
-                            MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];
-
-                            [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
-                        }
-                    } else {
-                        if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
-                        }
-
-                        int nth0 = 32;
-                        int nth1 = 1;
-
-                        // use custom matrix x vector kernel
-                        switch (src0t) {
-                            case GGML_TYPE_F16:
-                                {
-                                    GGML_ASSERT(ne02 == ne12);
-
-                                    nth0 = 64;
-                                    nth1 = 1;
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
-                                } break;
-                            case GGML_TYPE_Q4_0:
-                                {
-                                    GGML_ASSERT(ne02 == 1);
-                                    GGML_ASSERT(ne12 == 1);
-
-                                    nth0 = 8;
-                                    nth1 = 8;
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
-                                } break;
-                            case GGML_TYPE_Q4_1:
-                                {
-                                    GGML_ASSERT(ne02 == 1);
-                                    GGML_ASSERT(ne12 == 1);
-
-                                    nth0 = 8;
-                                    nth1 = 8;
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
-                                } break;
-                            case GGML_TYPE_Q2_K:
-                                {
-                                    GGML_ASSERT(ne02 == 1);
-                                    GGML_ASSERT(ne12 == 1);
-
-                                    nth0 = 4;
-                                    nth1 = 16;
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
-                                } break;
-                            case GGML_TYPE_Q3_K:
-                                {
-                                    GGML_ASSERT(ne02 == 1);
-                                    GGML_ASSERT(ne12 == 1);
-
-                                    nth0 = 4;
-                                    nth1 = 16;
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
-                                } break;
-                            case GGML_TYPE_Q4_K:
-                                {
-                                    GGML_ASSERT(ne02 == 1);
-                                    GGML_ASSERT(ne12 == 1);
-
-                                    nth0 = 4;
-                                    nth1 = 16;
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
-                                } break;
-                            case GGML_TYPE_Q5_K:
-                                {
-                                    GGML_ASSERT(ne02 == 1);
-                                    GGML_ASSERT(ne12 == 1);
-
-                                    nth0 = 4;
-                                    nth1 = 16;
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
-                                } break;
-                            case GGML_TYPE_Q6_K:
-                                {
-                                    GGML_ASSERT(ne02 == 1);
-                                    GGML_ASSERT(ne12 == 1);
-
-                                    nth0 = 4;
-                                    nth1 = 16;
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
-                                } break;
-                            default:
-                                {
-                                    fprintf(stderr, "Asserting on type %d\n",(int)src0t);
-                                    GGML_ASSERT(false && "not implemented");
-                                }
-                        };
-
-
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
-                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
-                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
-                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
-                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
-                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
-                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
-                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
-                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
-                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
-
-                        if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
-                            [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                        }
-                        else if (src0t == GGML_TYPE_Q2_K ||
-                                 src0t == GGML_TYPE_Q3_K ||
-                                 src0t == GGML_TYPE_Q4_K ||
-                                 src0t == GGML_TYPE_Q5_K ||
-                                 src0t == GGML_TYPE_Q6_K) {
-                            [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                        } else {
-                            [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                        }
-                    }
-                } break;
-            case GGML_OP_GET_ROWS:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    switch (src0->type) {
-                        case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
-                        case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
-                        case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
-                        case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
-                        case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
-                        case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
-                        case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
-                        case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
-                        default: GGML_ASSERT(false && "not implemented");
-                    }
-
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
-                    [encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];
-
-                    const int64_t n = ggml_nelements(src1);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-            case GGML_OP_RMS_NORM:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    const float eps = 1e-6f;
-
-                    const int nth = 256;
-
-                    [encoder setComputePipelineState:ctx->pipeline_rms_norm];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                    [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                    [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
-
-                    const int64_t nrows = ggml_nrows(src0);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                } break;
-            case GGML_OP_ROPE:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    const int n_dims = ((int32_t *) src1->data)[1];
-                    const int mode   = ((int32_t *) src1->data)[2];
-
-                    const int n_past = ((int32_t *)(src1->data))[0];
-
-                    [encoder setComputePipelineState:ctx->pipeline_rope];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                    [encoder setBytes:&ne00   length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&ne01   length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&ne02   length:sizeof( int64_t) atIndex:4];
-                    [encoder setBytes:&ne03   length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&nb00   length:sizeof(uint64_t) atIndex:6];
-                    [encoder setBytes:&nb01   length:sizeof(uint64_t) atIndex:7];
-                    [encoder setBytes:&nb02   length:sizeof(uint64_t) atIndex:8];
-                    [encoder setBytes:&nb03   length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&ne0    length:sizeof( int64_t) atIndex:10];
-                    [encoder setBytes:&ne1    length:sizeof( int64_t) atIndex:11];
-                    [encoder setBytes:&ne2    length:sizeof( int64_t) atIndex:12];
-                    [encoder setBytes:&ne3    length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&nb0    length:sizeof(uint64_t) atIndex:14];
-                    [encoder setBytes:&nb1    length:sizeof(uint64_t) atIndex:15];
-                    [encoder setBytes:&nb2    length:sizeof(uint64_t) atIndex:16];
-                    [encoder setBytes:&nb3    length:sizeof(uint64_t) atIndex:17];
-                    [encoder setBytes:&n_past length:sizeof(     int) atIndex:18];
-                    [encoder setBytes:&n_dims length:sizeof(     int) atIndex:19];
-                    [encoder setBytes:&mode   length:sizeof(     int) atIndex:20];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-            case GGML_OP_CPY:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    const int nth = 32;
-
-                    switch (src0t) {
-                        case GGML_TYPE_F32:
-                            {
-                                switch (dstt) {
-                                    case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
-                                    case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
-                                    default: GGML_ASSERT(false && "not implemented");
-                                };
-                            } break;
-                        default: GGML_ASSERT(false && "not implemented");
-                    }
-
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                    [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                    [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                    [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                } break;
-            default:
-                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                GGML_ASSERT(false);
-        }
-    }
+    // create multiple command buffers and enqueue them
+    // then, we encode the graph into the command buffers in parallel
+
+    const int n_cb = gf->n_threads;
+
+    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
+
+    for (int i = 0; i < n_cb; ++i) {
+        command_buffers[i] = [ctx->queue commandBuffer];
 
-    if (encoder != nil) {
-        [encoder endEncoding];
-        encoder = nil;
+        // enqueue the command buffers in order to specify their execution order
+        [command_buffers[i] enqueue];
     }
 
-    [command_buffer commit];
-    [command_buffer waitUntilCompleted];
+    // TODO: is this the best way to start threads?
+    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+
+    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+
+        dispatch_async(queue, ^{
+            size_t offs_src0 = 0;
+            size_t offs_src1 = 0;
+            size_t offs_dst  = 0;
+
+            id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
+
+            id<MTLComputeCommandEncoder> encoder = nil;
+
+            const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int i = node_start; i < node_end; ++i) {
+                metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+
+                struct ggml_tensor * src0 = gf->nodes[i]->src0;
+                struct ggml_tensor * src1 = gf->nodes[i]->src1;
+                struct ggml_tensor * dst  = gf->nodes[i];
+
+                const int64_t  ne00 = src0 ? src0->ne[0] : 0;
+                const int64_t  ne01 = src0 ? src0->ne[1] : 0;
+                const int64_t  ne02 = src0 ? src0->ne[2] : 0;
+                const int64_t  ne03 = src0 ? src0->ne[3] : 0;
+
+                const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+                const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+                const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+                const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+
+                const int64_t  ne10 = src1 ? src1->ne[0] : 0;
+                const int64_t  ne11 = src1 ? src1->ne[1] : 0;
+                const int64_t  ne12 = src1 ? src1->ne[2] : 0;
+                const int64_t  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+                const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+                const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+                const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+                const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+
+                const int64_t  ne0  = dst ? dst->ne[0] : 0;
+                const int64_t  ne1  = dst ? dst->ne[1] : 0;
+                const int64_t  ne2  = dst ? dst->ne[2] : 0;
+                const int64_t  ne3  = dst ? dst->ne[3] : 0;
+
+                const uint64_t nb0  = dst ? dst->nb[0] : 0;
+                const uint64_t nb1  = dst ? dst->nb[1] : 0;
+                const uint64_t nb2  = dst ? dst->nb[2] : 0;
+                const uint64_t nb3  = dst ? dst->nb[3] : 0;
+
+                const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+                const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+                const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
+
+                id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
+                id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
+                id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(ctx, dst,  &offs_dst)  : nil;
+
+                //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+                //if (src0) {
+                //    metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+                //            ggml_is_contiguous(src0), src0->name);
+                //}
+                //if (src1) {
+                //    metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+                //            ggml_is_contiguous(src1), src1->name);
+                //}
+                //if (dst) {
+                //    metal_printf("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
+                //            dst->name);
+                //}
+
+                switch (dst->op) {
+                    case GGML_OP_RESHAPE:
+                    case GGML_OP_VIEW:
+                    case GGML_OP_TRANSPOSE:
+                    case GGML_OP_PERMUTE:
+                        {
+                            // noop
+                        } break;
+                    case GGML_OP_ADD:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_MUL:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            if (ggml_nelements(src1) == ne10) {
+                                // src1 is a row
+                                [encoder setComputePipelineState:ctx->pipeline_mul_row];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_mul];
+                            }
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_SCALE:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const float scale = *(const float *) src1->data;
+
+                            [encoder setComputePipelineState:ctx->pipeline_scale];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_SILU:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_silu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_RELU:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_relu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_GELU:
+                    {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_gelu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                    case GGML_OP_SOFT_MAX:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int nth = 32;
+
+                            [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_DIAG_MASK_INF:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int n_past = ((int32_t *)(src1->data))[0];
+
+                            [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_MUL_MAT:
+                        {
+                            // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
+
+                            GGML_ASSERT(ne00 == ne10);
+                            GGML_ASSERT(ne02 == ne12);
+
+                            if (ggml_is_contiguous(src0) &&
+                                ggml_is_contiguous(src1) &&
+                                (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
+
+                                if (encoder != nil) {
+                                    [encoder endEncoding];
+                                    encoder = nil;
+                                }
 
-    {
-        const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
-        UNUSED(time_elapsed);
+                                MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+                                MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+
+                                // for F32 x F32 we use MPS
+                                MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
+                                    matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
+
+                                MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
+                                    matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
+
+                                MPSMatrixDescriptor * desc  = [MPSMatrixDescriptor
+                                    matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
+
+                                MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
+                                    initWithDevice:ctx->device transposeLeft:false transposeRight:true
+                                        resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
+
+                                // we need to do ne02 multiplications
+                                // TODO: is there a way to do this in parallel - currently very slow ..
+                                // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
+                                for (int64_t i02 = 0; i02 < ne02; ++i02) {
+                                    size_t offs_src0_cur = offs_src0 + i02*nb02;
+                                    size_t offs_src1_cur = offs_src1 + i02*nb12;
+                                    size_t offs_dst_cur  = offs_dst  + i02*nb2;
 
-        metal_printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
+                                    MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
+                                    MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
+                                    MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];
+
+                                    [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+                                }
+                            } else {
+                                if (encoder == nil) {
+                                    encoder = [command_buffer computeCommandEncoder];
+                                }
+
+                                int nth0 = 32;
+                                int nth1 = 1;
+
+                                // use custom matrix x vector kernel
+                                switch (src0t) {
+                                    case GGML_TYPE_F16:
+                                        {
+                                            GGML_ASSERT(ne02 == ne12);
+
+                                            nth0 = 64;
+                                            nth1 = 1;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_0:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_1:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+                                        } break;
+                                    case GGML_TYPE_Q2_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q3_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q6_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
+                                        } break;
+                                    default:
+                                        {
+                                            fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+                                            GGML_ASSERT(false && "not implemented");
+                                        }
+                                };
+
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
+                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
+                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
+                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
+                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
+                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
+                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
+
+                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
+                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src0t == GGML_TYPE_Q2_K ||
+                                         src0t == GGML_TYPE_Q3_K ||
+                                         src0t == GGML_TYPE_Q4_K ||
+                                         src0t == GGML_TYPE_Q5_K ||
+                                         src0t == GGML_TYPE_Q6_K) {
+                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                } else {
+                                    [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                            }
+                        } break;
+                    case GGML_OP_GET_ROWS:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            switch (src0->type) {
+                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
+                                case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+                                case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
+                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
+                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
+                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
+                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
+                                default: GGML_ASSERT(false && "not implemented");
+                            }
+
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
+                            [encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];
+
+                            const int64_t n = ggml_nelements(src1);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_RMS_NORM:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const float eps = 1e-6f;
+
+                            const int nth = 256;
+
+                            [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                            const int64_t nrows = ggml_nrows(src0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_ROPE:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int n_dims = ((int32_t *) src1->data)[1];
+                            const int mode   = ((int32_t *) src1->data)[2];
+
+                            const int n_past = ((int32_t *)(src1->data))[0];
+
+                            [encoder setComputePipelineState:ctx->pipeline_rope];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00   length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01   length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02   length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03   length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00   length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01   length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02   length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03   length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0    length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1    length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2    length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3    length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0    length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1    length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2    length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3    length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&n_past length:sizeof(     int) atIndex:18];
+                            [encoder setBytes:&n_dims length:sizeof(     int) atIndex:19];
+                            [encoder setBytes:&mode   length:sizeof(     int) atIndex:20];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_CPY:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int nth = 32;
+
+                            switch (src0t) {
+                                case GGML_TYPE_F32:
+                                    {
+                                        switch (dstt) {
+                                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+                                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+                                            default: GGML_ASSERT(false && "not implemented");
+                                        };
+                                    } break;
+                                default: GGML_ASSERT(false && "not implemented");
+                            }
+
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    default:
+                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                        GGML_ASSERT(false);
+                }
+            }
+
+            if (encoder != nil) {
+                [encoder endEncoding];
+                encoder = nil;
+            }
+
+            [command_buffer commit];
+        });
     }
+
+    // wait for all threads to finish
+    dispatch_barrier_sync(queue, ^{});
+
+    [command_buffers[n_cb - 1] waitUntilCompleted];
 }