]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : handle ggml_scale for n%4 != 0 (close #3754)
authorGeorgi Gerganov <redacted>
Tue, 24 Oct 2023 06:46:50 +0000 (09:46 +0300)
committerGeorgi Gerganov <redacted>
Tue, 24 Oct 2023 06:47:22 +0000 (09:47 +0300)
ggml-ci

ggml-metal.m
ggml-metal.metal

index c908106be5e680da9cad6c7d770a892fc08a69e7..c1901dca75269dbc86971b5dccccaaae8e2c871c 100644 (file)
@@ -62,6 +62,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
+    GGML_METAL_DECL_KERNEL(scale_4);
     GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
@@ -249,6 +250,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
+        GGML_METAL_ADD_KERNEL(scale_4);
         GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
@@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul);
     GGML_METAL_DEL_KERNEL(mul_row);
     GGML_METAL_DEL_KERNEL(scale);
+    GGML_METAL_DEL_KERNEL(scale_4);
     GGML_METAL_DEL_KERNEL(silu);
     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
@@ -923,15 +926,20 @@ void ggml_metal_graph_compute(
 
                             const float scale = *(const float *) src1->data;
 
-                            [encoder setComputePipelineState:ctx->pipeline_scale];
+                            int64_t n = ggml_nelements(dst);
+
+                            if (n % 4 == 0) {
+                                n /= 4;
+                                [encoder setComputePipelineState:ctx->pipeline_scale_4];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_scale];
+                            }
+
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                            const int64_t n = ggml_nelements(dst);
-                            GGML_ASSERT(n % 4 == 0);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_OP_UNARY:
                         switch (ggml_get_unary_op(gf->nodes[i])) {
index 69fc7136279ad4c8fbc934bd1689376ad6734d90..f4b460564453c5b3cfb1e0b545fea0d5e2c6182e 100644 (file)
@@ -125,9 +125,17 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
         device const float4 * src0,
         device       float4 * dst,
-        constant     float & scale,
+        constant     float  & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }