]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : slight speed-up for add and mul kernels (#2917)
authorGeorgi Gerganov <redacted>
Fri, 1 Sep 2023 10:42:41 +0000 (13:42 +0300)
committerGitHub <redacted>
Fri, 1 Sep 2023 10:42:41 +0000 (13:42 +0300)
ggml-metal.m
ggml-metal.metal

index 8c3c64f53f00a7186f9cb12236c5b389a198c57f..4267db9be3e61db79f0278966d06c82bd32ce238 100644 (file)
@@ -680,6 +680,12 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_ADD:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
+                            // utilize float4
+                            GGML_ASSERT(ne00 % 4 == 0);
+                            const int64_t nb = ne00/4;
+
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -689,14 +695,20 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&nb     length:sizeof(nb) atIndex:3];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_OP_MUL:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
+                            // utilize float4
+                            GGML_ASSERT(ne00 % 4 == 0);
+                            const int64_t nb = ne00/4;
+
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -706,9 +718,9 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&nb     length:sizeof(nb) atIndex:3];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
index 02db5323ea0f27956cd6f756b273e8bc2e91df2b..8cdf0b9d2ba0a39aa92c8fbdd02e05795839e529 100644 (file)
@@ -25,9 +25,9 @@ typedef struct {
 } block_q8_0;
 
 kernel void kernel_add(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig];
 }
@@ -35,18 +35,18 @@ kernel void kernel_add(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
 kernel void kernel_mul(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig];
 }
@@ -54,12 +54,12 @@ kernel void kernel_mul(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_mul_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant    int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
 kernel void kernel_scale(