]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cuda : support stablelm rope (#4156)
authorslaren <redacted>
Fri, 24 Nov 2023 17:04:31 +0000 (18:04 +0100)
committerGitHub <redacted>
Fri, 24 Nov 2023 17:04:31 +0000 (18:04 +0100)
* ggml-cuda : support stablelm rope

* remove unused freq_base kernel parameter

* add n_dims parameter to llm_build_k_shift, default to n_rot via overload

* llama : fix llm_build_k_shift args

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml-cuda.cu
llama.cpp

index f0db7ae357a2fefb22c73f3c893eb4f94085b847..5b80e4ae313293bdc630a45e57567a2efea2dcc3 100644 (file)
@@ -4610,8 +4610,8 @@ static __global__ void rope(
 
 template<typename T, bool has_pos>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
+    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
@@ -4620,23 +4620,25 @@ static __global__ void rope_neox(
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i = row*ncols + col/2;
+    const int ib = col / n_dims;
+    const int ic = col % n_dims;
+
+    const int i = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;
 
-    // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
-    const float cur_rot = -float(col)/ncols;
+    float cur_rot = inv_ndims * ic - ib;
 
     const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, cur_rot);
+    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
-    const float x1 = x[i + ncols/2];
+    const float x1 = x[i + n_dims/2];
 
-    dst[i + 0]       = x0*cos_theta - x1*sin_theta;
-    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
 static __global__ void rope_glm_f32(
@@ -5739,20 +5741,26 @@ static void rope_cuda(
 
 template<typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
 ) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float inv_ndims = -1.0f / n_dims;
+
     if (pos == nullptr) {
         rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, inv_ndims
         );
     } else {
         rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, inv_ndims
         );
     }
 }
@@ -6707,15 +6715,14 @@ inline void ggml_cuda_op_rope(
         GGML_ASSERT(false);
         rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
     } else if (is_neox) {
-        GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_cuda(
-                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, main_stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_cuda(
-                (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, main_stream
             );
         } else {
index 9fb7244b41cf52b319e6c635c65e738fdfb207b2..5b31f20164470895550083369fe402278ec60ba0 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -3469,7 +3469,7 @@ static void llm_build_k_shift(
        struct ggml_cgraph * graph,
             llm_rope_type   type,
                   int64_t   n_ctx,
-                  int64_t   n_rot,
+                  int       n_rot,
                   float     freq_base,
                   float     freq_scale,
        const llm_build_cb & cb) {
@@ -3501,7 +3501,7 @@ static void llm_build_k_shift(
             // we rotate only the first n_rot dimensions
             ggml_rope_custom_inplace(ctx,
                     ggml_view_3d(ctx, kv.k,
-                        n_rot, n_head_kv, n_ctx,
+                        n_embd_head, n_head_kv, n_ctx,
                         ggml_element_size(kv.k)*n_embd_head,
                         ggml_element_size(kv.k)*n_embd_gqa,
                         ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),