metal : add gqa8 kernel to allow llama-2-70B on metal (#2459)
author    Matteo Boschini <redacted>
          Tue, 1 Aug 2023 07:43:12 +0000 (09:43 +0200)
committer GitHub <redacted>
          Tue, 1 Aug 2023 07:43:12 +0000 (10:43 +0300)
* Added gqa8 kernel to allow llama-2-70B on metal

* Update ggml-metal.m

Co-authored-by: Cebtenzzre <redacted>
* Extend kernel_mul_mat_f16_f32 to handle gqa broadcast

* Added ne03==ne13 assertion

---------

Co-authored-by: Cebtenzzre <redacted>
ggml-metal.m
ggml-metal.metal

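For context, the patch broadcasts src0 over the grouped-query heads: src1 has ne12 slices in dimension 2 while src0 only has ne02, with ne12 assumed to be a multiple of ne02 (gqa = 8 for LLaMA-2-70B), so slice i12 of src1 is multiplied against slice i12/(ne12/ne02) of src0. Below is a minimal standalone C sketch of that index mapping; the helper name and the example head counts are illustrative only and not part of the patch.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

// Illustrative helper: map a src1 batch index to the src0 batch it reads
// when src0 is broadcast across grouped-query heads (requires ne12 % ne02 == 0).
static int64_t gqa_src0_batch(int64_t i12, int64_t ne02, int64_t ne12) {
    assert(ne12 % ne02 == 0); // same precondition the kernel relies on
    return i12 / (ne12 / ne02);
}

int main(void) {
    // Example: 64 query heads sharing 8 KV heads (gqa = 8, as in LLaMA-2-70B).
    const int64_t ne02 = 8, ne12 = 64;
    for (int64_t i12 = 0; i12 < ne12; ++i12) {
        printf("src1 slice %2lld -> src0 slice %lld\n",
               (long long) i12, (long long) gqa_src0_batch(i12, ne02, ne12));
    }
    return 0;
}
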
diff --git a/ggml-metal.m b/ggml-metal.m
index 74a6bff40411784f2b13ca4c1a7bf607bfc400c4..3f098d39677a0ee49c0952243c82abdc63f5ad0f 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -718,7 +718,8 @@ void ggml_metal_graph_compute(
                             // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
 
                             GGML_ASSERT(ne00 == ne10);
-                            GGML_ASSERT(ne02 == ne12);
+                            // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+                            GGML_ASSERT(ne03 == ne13);
 
                             if (ggml_is_contiguous(src0) &&
                                 ggml_is_contiguous(src1) &&
@@ -746,11 +747,11 @@ void ggml_metal_graph_compute(
                                     initWithDevice:ctx->device transposeLeft:false transposeRight:true
                                         resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
 
-                                // we need to do ne02 multiplications
+                                // we need to do ne12 multiplications
                                 // TODO: is there a way to do this in parallel - currently very slow ..
                                 // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
-                                for (int64_t i02 = 0; i02 < ne02; ++i02) {
-                                    size_t offs_src0_cur = offs_src0 + i02*nb02;
+                                for (int64_t i02 = 0; i02 < ne12; ++i02) {
+                                    size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
                                     size_t offs_src1_cur = offs_src1 + i02*nb12;
                                     size_t offs_dst_cur  = offs_dst  + i02*nb2;
 
@@ -772,8 +773,6 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                     case GGML_TYPE_F16:
                                         {
-                                            GGML_ASSERT(ne02 == ne12);
-
                                             nth0 = 64;
                                             nth1 = 1;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@@ -853,16 +852,18 @@ void ggml_metal_graph_compute(
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                                 [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
                                 [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
-                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
-                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
-                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
-                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
-                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
-                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
-                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
-                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
-                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
+                                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+                                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
+                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
+                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
+                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
+                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
+                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 696b33ce75cf4fa8d92850ef5e762c32190d6410..8d26b5ec2dfa4f649b7a6d478f73238a0c700da6 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
+        constant   int64_t & ne02,
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
+        constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
@@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
     const int64_t r1 = tgpig.y;
     const int64_t im = tgpig.z;
 
-    device const half  * x = (device const half  *) (src0 + r0*nb01 + im*nb02);
+    device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
     sum[tpitg.x] = 0.0f;
@@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+
 kernel void kernel_alibi_f32(
         device const float * src0,
         device       float * dst,