]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : fix build errors and kernel sig after #2268 (#3898)
authorGeorgi Gerganov <redacted>
Thu, 2 Nov 2023 06:33:37 +0000 (08:33 +0200)
committerGitHub <redacted>
Thu, 2 Nov 2023 06:33:37 +0000 (08:33 +0200)
ggml-metal.m
ggml-metal.metal

index 611d5e173681eb5a6e8d967beb84488ee933cb4d..b33a3cb8fd0128ac6478b9228e8609f75298ad96 100644 (file)
@@ -1419,34 +1419,35 @@ void ggml_metal_graph_compute(
                                 default: GGML_ASSERT(false);
                             };
 
-                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1        atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:2];
-                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:6];
-                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:10];
-                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:14];
-                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:17];
-                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:18];
-                            [encoder setBytes:&n_past  length:sizeof(     int) atIndex:19];
-                            [encoder setBytes:&n_dims  length:sizeof(     int) atIndex:20];
-                            [encoder setBytes:&mode    length:sizeof(     int) atIndex:21];
-                            [encoder setBytes:&freq_base  length:sizeof(float) atIndex:22];
-                            [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
-                            [encoder setBytes:&ext_factor  length:sizeof(float) atIndex:24];
-                            [encoder setBytes:&attn_factor length:sizeof(float) atIndex:25];
-                            [encoder setBytes:&beta_fast   length:sizeof(float) atIndex:26];
-                            [encoder setBytes:&beta_slow   length:sizeof(float) atIndex:27];
+                            [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
+                            [encoder setBuffer:id_dst      offset:offs_dst         atIndex:2];
+                            [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:6];
+                            [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:10];
+                            [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:14];
+                            [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:18];
+                            [encoder setBytes:&n_past      length:sizeof(     int) atIndex:19];
+                            [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:20];
+                            [encoder setBytes:&mode        length:sizeof(     int) atIndex:21];
+                            [encoder setBytes:&n_orig_ctx  length:sizeof(     int) atIndex:22];
+                            [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23];
+                            [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24];
+                            [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25];
+                            [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:26];
+                            [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:27];
+                            [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:28];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
index 471d7d390f8138bc657e9a342a28407380659de7..7c35f23a7612fd75362457b3fc9d137cd37e0bfa 100644 (file)
@@ -1070,20 +1070,20 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
+    thread float * cos_theta, thread float * sin_theta
 ) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
     if (ext_factor != 0.0f) {
-        ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 
         // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
     }
-    *cos_theta = cosf(theta) * mscale;
-    *sin_theta = sinf(theta) * mscale;
+    *cos_theta = cos(theta) * mscale;
+    *sin_theta = sin(theta) * mscale;
 }
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
@@ -1123,8 +1123,13 @@ typedef void (rope_t)(
         constant         int & n_past,
         constant         int & n_dims,
         constant         int & mode,
+        constant         int & n_orig_ctx,
         constant       float & freq_base,
         constant       float & freq_scale,
+        constant       float & ext_factor,
+        constant       float & attn_factor,
+        constant       float & beta_fast,
+        constant       float & beta_slow,
         uint  tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
         uint3 tgpig[[threadgroup_position_in_grid]]);
@@ -1153,6 +1158,7 @@ kernel void kernel_rope(
         constant         int & n_past,
         constant         int & n_dims,
         constant         int & mode,
+        constant         int & n_orig_ctx,
         constant       float & freq_base,
         constant       float & freq_scale,
         constant       float & ext_factor,