]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : support bcast add & dup & cont op (#2323)
authorJiahao Li <redacted>
Sun, 23 Jul 2023 11:00:37 +0000 (19:00 +0800)
committerGitHub <redacted>
Sun, 23 Jul 2023 11:00:37 +0000 (14:00 +0300)
ggml-metal.m
ggml-metal.metal

index 2810fa2a841c5008d3d68b6b647db517680ab715..78a3b65f1959265a617a69f8a0e90f9e626fa1fa 100644 (file)
@@ -42,6 +42,7 @@ struct ggml_metal_context {
     id<MTLComputePipelineState> pipeline_##name
 
     GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
@@ -157,6 +158,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
 
         GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
@@ -464,10 +466,16 @@ void ggml_metal_graph_compute(
                                 encoder = [command_buffer computeCommandEncoder];
                             }
 
-                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            if (ggml_nelements(src1) == ne10) {
+                                // src1 is a row
+                                [encoder setComputePipelineState:ctx->pipeline_add_row];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_add];
+                            }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
                             const int64_t n = ggml_nelements(dst);
 
@@ -919,7 +927,9 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
+                    case GGML_OP_DUP:
                     case GGML_OP_CPY:
+                    case GGML_OP_CONT:
                         {
                             if (encoder == nil) {
                                 encoder = [command_buffer computeCommandEncoder];
index 5a9a6d842387d06bec1df96c3aa3cb9c2f7b04af..987376d560879b97af68f85b066b6e5e4fb5ad6d 100644 (file)
@@ -67,6 +67,17 @@ kernel void kernel_add(
     dst[tpig] = src0[tpig] + src1[tpig];
 }
 
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+}
+
 kernel void kernel_mul(
         device const float * src0,
         device const float * src1,