]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : parallel RoPE on Metal (#3024)
authorKawrakow <redacted>
Thu, 7 Sep 2023 13:45:01 +0000 (15:45 +0200)
committerGitHub <redacted>
Thu, 7 Sep 2023 13:45:01 +0000 (16:45 +0300)
* Parallel RoPE on metal

* PR suggestion

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-metal.m
ggml-metal.metal

index 521ca180f085be10a3972d3a2e3ea386556a1bdb..7e2355ce6bcc7eeb8abe7f5290cad19e9629cb35 100644 (file)
@@ -1141,7 +1141,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
                             [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
 
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
                         } break;
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
index 5edf6d521c2d3ed52175b005adcef8923af4124e..5070561fba1ace47d5b224d13ec8e9a0a7771773 100644 (file)
@@ -682,25 +682,27 @@ kernel void kernel_rope(
         constant       int & mode,
         constant     float & freq_base,
         constant     float & freq_scale,
-        uint3 tpig[[thread_position_in_grid]]) {
-    const int64_t i3 = tpig[2];
-    const int64_t i2 = tpig[1];
-    const int64_t i1 = tpig[0];
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg[[threads_per_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
 
     const bool is_neox = mode & 2;
-    const float theta_scale = pow(freq_base, -2.0f/n_dims);
 
     const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
 
-    float theta = freq_scale * (float)p;
+    const float theta_0 = freq_scale * (float)p;
+    const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
-        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+
+            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             const float cos_theta = cos(theta);
             const float sin_theta = sin(theta);
 
-            theta *= theta_scale;
-
             device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
@@ -712,12 +714,12 @@ kernel void kernel_rope(
         }
     } else {
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+            for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+
+                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
                 const float cos_theta = cos(theta);
                 const float sin_theta = sin(theta);
 
-                theta *= theta_scale;
-
                 const int64_t i0 = ib*n_dims + ic/2;
 
                 device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);