]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : refactor rope norm/neox (#7634)
authorGeorgi Gerganov <redacted>
Wed, 5 Jun 2024 08:29:20 +0000 (11:29 +0300)
committerGitHub <redacted>
Wed, 5 Jun 2024 08:29:20 +0000 (11:29 +0300)
* ggml : unify rope norm/neox (CPU)

* ggml : fix compile warning

* ggml : remove GLM rope mode

ggml-ci

* metal : better rope implementation

ggml-ci

* cuda : better rope implementation

ggml-ci

* naming : n_orig_ctx -> n_ctx_orig

ggml-ci

* dev : add reminders to update backends

ggml-ci

* vulkan : fix ggml_rope_ext() usage

* cuda : fix array size + indents

ggml-ci

19 files changed:
examples/baby-llama/baby-llama.cpp
examples/convert-legacy-llama.py
examples/finetune/finetune.cpp
examples/train-text-from-scratch/train-text-from-scratch.cpp
ggml-cuda/rope.cu
ggml-kompute.cpp
ggml-metal.m
ggml-metal.metal
ggml-sycl.cpp
ggml-vulkan.cpp
ggml.c
ggml.h
kompute-shaders/op_rope_f16.comp
kompute-shaders/op_rope_f32.comp
kompute-shaders/rope_common.comp
llama.cpp
tests/test-backend-ops.cpp
tests/test-grad0.cpp
tests/test-rope.cpp

index bf0125e75374645365552419e7ceb5d2b97349d3..4f6c3746a106c9b8a5b86890d7a96b496ad4c428 100644 (file)
@@ -522,8 +522,8 @@ static struct ggml_tensor * forward(
             // wk   shape [n_embd, n_embd, 1, 1]
             // Qcur shape [n_embd/n_head, n_head, N, 1]
             // Kcur shape [n_embd/n_head, n_head, N, 1]
-            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
-            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
+            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0);
 
             // store key and value to memory
             {
@@ -759,8 +759,8 @@ static struct ggml_tensor * forward_batch(
             // wk   shape [n_embd, n_embd, 1, 1]
             // Qcur shape [n_embd/n_head, n_head, N, n_batch]
             // Kcur shape [n_embd/n_head, n_head, N, n_batch]
-            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
-            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
+            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0);
             assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
             assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
 
@@ -1056,7 +1056,7 @@ static struct ggml_tensor * forward_lora(
                                                         model->layers[il].wqb,
                                                         cur)),
                                                 n_embd/n_head, n_head, N),
-                                            KQ_pos, n_rot, 0, 0);
+                                            KQ_pos, n_rot, 0);
             struct ggml_tensor * Kcur = ggml_rope(ctx0,
                                             ggml_reshape_3d(ctx0,
                                                 ggml_mul_mat(ctx0,
@@ -1065,7 +1065,7 @@ static struct ggml_tensor * forward_lora(
                                                         model->layers[il].wkb,
                                                         cur)),
                                                 n_embd/n_head, n_head, N),
-                                            KQ_pos, n_rot, 0, 0);
+                                            KQ_pos, n_rot, 0);
 
             // store key and value to memory
             {
index fd840101569a9da58cc9ddc36512270f6054730d..721a57c00299b4583f16c76130e4e5ef155aaeaf 100755 (executable)
@@ -176,7 +176,7 @@ class Params:
     rope_scaling_type: gguf.RopeScalingType | None = None
     f_rope_freq_base: float | None = None
     f_rope_scale: float | None = None
-    n_orig_ctx: int | None = None
+    n_ctx_orig: int | None = None
     rope_finetuned: bool | None = None
 
     ftype: GGMLFileType | None = None
@@ -226,7 +226,7 @@ class Params:
         with open(config_path) as f:
             config = json.load(f)
 
-        rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
+        rope_scaling_type = f_rope_scale = n_ctx_orig = rope_finetuned = None
         rope_scaling = config.get("rope_scaling")
 
         if rope_scaling is not None and (typ := rope_scaling.get("type")):
@@ -236,7 +236,7 @@ class Params:
                 rope_scaling_type = gguf.RopeScalingType.LINEAR
             elif typ == "yarn":
                 rope_scaling_type = gguf.RopeScalingType.YARN
-                n_orig_ctx = rope_scaling['original_max_position_embeddings']
+                n_ctx_orig = rope_scaling['original_max_position_embeddings']
                 rope_finetuned = rope_scaling['finetuned']
             else:
                 raise NotImplementedError(f'Unknown rope scaling type: {typ}')
@@ -272,7 +272,7 @@ class Params:
             f_rope_freq_base  = config.get("rope_theta"),
             rope_scaling_type = rope_scaling_type,
             f_rope_scale      = f_rope_scale,
-            n_orig_ctx        = n_orig_ctx,
+            n_ctx_orig        = n_ctx_orig,
             rope_finetuned    = rope_finetuned,
         )
 
@@ -864,8 +864,8 @@ class OutputFile:
             self.gguf.add_rope_scaling_type(params.rope_scaling_type)
             self.gguf.add_rope_scaling_factor(params.f_rope_scale)
 
-        if params.n_orig_ctx is not None:
-            self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx)
+        if params.n_ctx_orig is not None:
+            self.gguf.add_rope_scaling_orig_ctx_len(params.n_ctx_orig)
 
         if params.rope_finetuned is not None:
             self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)
index 22425730f20eb070343d557d1df10e6b6caa481e..71a4333ee790844a4b3aa1678e98d9fc7edc96b6 100644 (file)
@@ -564,7 +564,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
         const int rope_mode = 0;
 
         return ggml_rope_ext(ctx,
-            t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0,
+            t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx,
             rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
         );
     };
index e2f85c68297b80ce5cdf25f4d5a6451118e3e062..b779f6bd49876bfe4c3f9bf0d8dd1efc2dc93877 100644 (file)
@@ -302,7 +302,7 @@ static struct ggml_tensor * llama_build_train_graphs(
         const int rope_mode = 0;
 
         return ggml_rope_ext(
-            ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
+            ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
         );
     };
 
index 0dd07977ebab1e263db0419c67da9e7ae0c79cd3..596fb7c135058e1c66896ffb7944f1e7f3783d82 100644 (file)
@@ -1,7 +1,7 @@
 #include "rope.cuh"
 
 struct rope_corr_dims {
-    float v[4];
+    float v[2];
 };
 
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static __device__ void rope_yarn(
     float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
+    float * cos_theta, float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -29,27 +28,38 @@ static __device__ void rope_yarn(
     *sin_theta = sinf(theta) * mscale;
 }
 
-// rope == RoPE == rotary positional embedding
-template<typename T, bool has_pos>
-static __global__ void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
-) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+template<typename T, bool has_ff>
+static __global__ void rope_norm(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i = row*ncols + col;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ne0 + i0;
     const int i2 = row/p_delta_rows;
 
-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, -float(col)/ncols);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + 1];
@@ -58,23 +68,20 @@ static __global__ void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_pos, bool has_freq_facs>
+template<typename T, bool has_ff>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
-) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;
 
-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -82,16 +89,17 @@ static __global__ void rope_neox(
         return;
     }
 
-    const int i  = row*ncols + ib*n_dims + ic/2;
+    const int i  = row*ne0 + i0/2;
     const int i2 = row/p_delta_rows;
 
-    const int p = has_pos ? pos[i2] : 0;
-    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
 
-    const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];
@@ -100,144 +108,81 @@ static __global__ void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
-     // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
-    const float sin_theta = sinf(theta);
-    const float cos_theta = cosf(theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
-    const float sin_block_theta = sinf(block_theta);
-    const float cos_block_theta = cosf(block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
-
 template<typename T>
-static void rope_cuda(
-    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 2 == 0);
+static void rope_norm_cuda(
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
-    if (pos == nullptr) {
-        rope<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
     } else {
-        rope<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
     }
 }
 
 template<typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 2 == 0);
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
-    if (pos == nullptr) {
-        if (freq_factors == nullptr) {
-            rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
-        } else {
-            rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+    if (freq_factors == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                 theta_scale, freq_factors
                 );
-        }
     } else {
-        if (freq_factors == nullptr) {
-            rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
-        } else {
-            rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                 theta_scale, freq_factors
                 );
-        }
     }
 }
 
-static void rope_glm_f32_cuda(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, int n_ctx, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
-    const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
-    const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
-}
-
-static void rope_cuda_f16(
-    const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f16(
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_cuda<half>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
-static void rope_cuda_f32(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f32(
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_cuda<float>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f16(
-    const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f32(
-    const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
 ) {
 
-    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -258,16 +203,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nr = ggml_nrows(src0);
 
-    //const int n_past      = ((int32_t *) dst->op_params)[0];
-    const int n_dims      = ((int32_t *) dst->op_params)[1];
-    const int mode        = ((int32_t *) dst->op_params)[2];
-    const int n_ctx       = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx  = ((int32_t *) dst->op_params)[4];
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
     memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
@@ -275,38 +226,28 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
-    const float * freq_factors = nullptr;
-    const int32_t * pos = nullptr;
-
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
 
-    pos = (const int32_t *) src1_d;
+    const int32_t * pos = (const int32_t *) src1_d;
 
-    if (is_neox) {
-        if (src2 != nullptr) {
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+    const float * freq_factors = nullptr;
+    if (src2 != nullptr) {
+        freq_factors = (const float *) src2->data;
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, stream
             );
         } else {
@@ -314,14 +255,14 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
-            rope_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else {
             GGML_ASSERT(false);
index eabd70d5eeed86eb2d465bee672eabf015a54c0d..5592741be425580f351ffa151a6e1e2161c95669 100644 (file)
@@ -1192,7 +1192,7 @@ static void ggml_vk_rope(
     const std::shared_ptr<kp::Tensor>& inB,
     const std::shared_ptr<kp::Tensor>& out,
     uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
+    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
     float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
     int32_t ne01, int32_t ne02, int32_t ne03,
     uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
@@ -1221,14 +1221,14 @@ static void ggml_vk_rope(
 
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
-        int32_t n_dims, mode, n_orig_ctx;
+        int32_t n_dims, mode, n_ctx_orig;
         float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
         uint32_t nb00, nb01, nb02, nb03;
         int32_t ne0;
         uint32_t nb0, nb1, nb2, nb3;
     } pushConsts {
         safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
-        n_dims, mode, n_orig_ctx,
+        n_dims, mode, n_ctx_orig,
         freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
         nb00, nb01, nb02, nb03,
         ne0,
@@ -1692,13 +1692,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 #pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
                         GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
 
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
+
                         GGML_ASSERT(ne10 == ne02);
                         GGML_ASSERT(src0t == dstt);
                         // const int n_past = ((int32_t *) dst->op_params)[0];
                         const int n_dims     = ((int32_t *) dst->op_params)[1];
                         const int mode       = ((int32_t *) dst->op_params)[2];
                         // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
-                        const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+                        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
                         float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                         memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
@@ -1708,7 +1711,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                         memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
                         memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
                         ggml_vk_rope(
-                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
+                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
                             freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
                             ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
                         );
index fddc44f78d8afc5198e795625ff768a80e784821..946f11813dcf994c3790dd31b3e47378556949cc 100644 (file)
@@ -172,8 +172,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                      rope_f32,                       true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                      rope_f16,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                 rope_norm_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                 rope_norm_f16,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                 rope_neox_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                 rope_neox_f16,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
@@ -2285,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         const int n_dims     = ((int32_t *) dst->op_params)[1];
                         const int mode       = ((int32_t *) dst->op_params)[2];
                         // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
-                        const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+                        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
                         float freq_base;
                         float freq_scale;
@@ -2302,22 +2306,23 @@ static enum ggml_status ggml_metal_graph_compute(
                         memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
                         const bool is_neox = mode & 2;
-                        const bool is_glm  = mode & 4;
 
-                        GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
+                        id<MTLComputePipelineState> pipeline = nil;
 
                         if (!is_neox) {
-                            GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
+                            switch (src0->type) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                                default: GGML_ASSERT(false);
+                            };
+                        } else {
+                            switch (src0->type) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                                default: GGML_ASSERT(false);
+                            };
                         }
 
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        switch (src0->type) {
-                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
-                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
-                            default: GGML_ASSERT(false);
-                        };
-
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
                         [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
@@ -2345,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:19];
                         [encoder setBytes:&n_past      length:sizeof(     int) atIndex:20];
                         [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:21];
-                        [encoder setBytes:&mode        length:sizeof(     int) atIndex:22];
-                        [encoder setBytes:&n_orig_ctx  length:sizeof(     int) atIndex:23];
-                        [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:24];
-                        [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:25];
-                        [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:26];
-                        [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:27];
-                        [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:28];
-                        [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:29];
+                        [encoder setBytes:&n_ctx_orig  length:sizeof(     int) atIndex:22];
+                        [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23];
+                        [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24];
+                        [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25];
+                        [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:26];
+                        [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:27];
+                        [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:28];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
index 0cb85e1a5bad4daeabe6e1de00656402556e9f8b..e2796fd601281166213e81575fe56b54c97f2c72 100644 (file)
@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    thread float * cos_theta, thread float * sin_theta
-) {
+    thread float * cos_theta, thread float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -1672,19 +1671,20 @@ static void rope_yarn(
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
-    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
 }
 
 static void rope_yarn_corr_dims(
-    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
 ) {
     // start and end correction dims
-    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
 }
 
-typedef void (rope_t)(
+template<typename T>
+kernel void kernel_rope_norm(
         device const    void * src0,
         device const int32_t * src1,
         device const   float * src2,
@@ -1707,8 +1707,7 @@ typedef void (rope_t)(
         constant    uint64_t & nb3,
         constant         int & n_past,
         constant         int & n_dims,
-        constant         int & mode,
-        constant         int & n_orig_ctx,
+        constant         int & n_ctx_orig,
         constant       float & freq_base,
         constant       float & freq_scale,
         constant       float & ext_factor,
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
         constant       float & beta_slow,
         uint  tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
-        uint3 tgpig[[threadgroup_position_in_grid]]);
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
 
 template<typename T>
-kernel void kernel_rope(
+kernel void kernel_rope_neox(
         device const    void * src0,
         device const int32_t * src1,
         device const   float * src2,
@@ -1743,8 +1784,7 @@ kernel void kernel_rope(
         constant    uint64_t & nb3,
         constant         int & n_past,
         constant         int & n_dims,
-        constant         int & mode,
-        constant         int & n_orig_ctx,
+        constant         int & n_ctx_orig,
         constant       float & freq_base,
         constant       float & freq_scale,
         constant       float & ext_factor,
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
     const int64_t i2 = tgpig[1];
     const int64_t i1 = tgpig[0];
 
-    const bool is_neox = mode & 2;
-
     float corr_dims[2];
-    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     device const int32_t * pos = src1;
 
-    const int64_t p = pos[i2];
-
-    const float theta_base = (float)p;
+    const float theta_base = (float) pos[i2];
     const float inv_ndims = -1.f/n_dims;
 
-    if (!is_neox) {
-        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
-
-            float cos_theta, sin_theta;
-            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-            const T x0 = src[0];
-            const T x1 = src[1];
+    float cos_theta;
+    float sin_theta;
 
-            dst_data[0] = x0*cos_theta - x1*sin_theta;
-            dst_data[1] = x0*sin_theta + x1*cos_theta;
-        }
-    } else {
-        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
-            if (ic < n_dims) {
-                const int64_t i0 = ic/2;
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
 
-                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
-
-                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
 
-                float cos_theta, sin_theta;
-                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
 
-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-                const float x0 = src[0];
-                const float x1 = src[n_dims/2];
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
 
-                dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-            } else {
-                const int64_t i0 = ic;
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
 
-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-                dst_data[0] = src[0];
-                dst_data[1] = src[1];
-            }
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
         }
     }
 }
 
-template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
-template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
 
 typedef void (im2col_t)(
         device const float * x,
index 5cd97e4ff98df55983c43be14025dea07c609ea8..3ff76474d7e72e7373e564b7de830169bd92a77e 100644 (file)
@@ -8928,49 +8928,6 @@ static void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-, const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
-     // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
-    const float sin_theta = sycl::sin((float)theta);
-    const float cos_theta = sycl::cos((float)theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta =
-        ((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
-    const float sin_block_theta = sycl::sin((float)block_theta);
-    const float cos_block_theta = sycl::cos((float)block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
     }
 }
 
-static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
-                              const int32_t *pos, float freq_scale,
-                              int p_delta_rows, float freq_base, int n_ctx,
-                              dpct::queue_ptr stream) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
-    const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, nrows, num_blocks_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             rope_glm_f32(x, dst, ncols, pos, freq_scale,
-                                          p_delta_rows, freq_base, n_ctx,
-                                          item_ct1);
-                         });
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, dpct::queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -14066,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     //const int n_past      = ((int32_t *) dst->op_params)[0];
     const int n_dims      = ((int32_t *) dst->op_params)[1];
     const int mode        = ((int32_t *) dst->op_params)[2];
-    const int n_ctx       = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx  = ((int32_t *) dst->op_params)[4];
+    //const int n_ctx       = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig  = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -14087,7 +14028,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
+
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
 
     if (is_neox) {
         pos = (const int32_t *) src1_dd;
@@ -14100,13 +14043,10 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
index 5e12ea9dde4d7b53d217451b2c3e040f686625a9..e0c512c0dab0f5bc9ab504618945b54a117cb81b 100644 (file)
@@ -3898,11 +3898,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         {
             const int mode = ((const int32_t *) dst->op_params)[2];
             const bool is_neox = mode & 2;
-            const bool is_glm  = mode & 4;
-
-            if (is_glm) {
-                return nullptr;
-            }
 
             if (is_neox) {
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -4401,7 +4396,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
     const int n_dims        = ((int32_t *) dst->op_params)[1];
     const int mode          = ((int32_t *) dst->op_params)[2];
     // const int n_ctx         = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx    = ((int32_t *) dst->op_params)[4];
+    const int n_ctx_orig    = ((int32_t *) dst->op_params)[4];
     const float freq_base   = ((float *)   dst->op_params)[5];
     const float freq_scale  = ((float *)   dst->op_params)[6];
     const float ext_factor  = ((float *)   dst->op_params)[7];
@@ -4410,12 +4405,12 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
     const float beta_slow   = ((float *)   dst->op_params)[10];
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
 
-    GGML_ASSERT(!is_glm);
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
 
     float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     if (is_neox) {
         const float theta_scale = powf(freq_base, -2.0f/n_dims);
@@ -6485,9 +6480,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_ROPE:
             {
                 const int mode = ((const int32_t *) op->op_params)[2];
-                const bool is_glm  = mode & 4;
 
-                return !is_glm;
+                return true;
             } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
@@ -6992,15 +6986,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
     } else if (tensor->op == GGML_OP_ROPE) {
         const int n_dims      = ((int32_t *) tensor->op_params)[1];
         const int mode        = ((int32_t *) tensor->op_params)[2];
-        const int n_ggml_ctx       = ((int32_t *) tensor->op_params)[3];
-        const int n_orig_ggml_ctx  = ((int32_t *) tensor->op_params)[4];
+        //const int n_ctx_ggml       = ((int32_t *) tensor->op_params)[3];
+        const int n_ctx_orig_ggml  = ((int32_t *) tensor->op_params)[4];
         float freq_base       = ((float *)   tensor->op_params)[5];
         float freq_scale      = ((float *)   tensor->op_params)[6];
         float ext_factor      = ((float *)   tensor->op_params)[7];
         float attn_factor     = ((float *)   tensor->op_params)[8];
         float beta_fast       = ((float *)   tensor->op_params)[9];
         float beta_slow       = ((float *)   tensor->op_params)[10];
-        tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ggml_ctx, n_orig_ggml_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+        tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
diff --git a/ggml.c b/ggml.c
index 11e5c34ab56adef2c6857ca927ea19e044fa0984..1fc77743bc7b945baa6441845735a73870949ada 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -6250,16 +6250,13 @@ static struct ggml_tensor * ggml_rope_impl(
         struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
         float                 attn_factor,
         float                 beta_fast,
         float                 beta_slow,
-        float                 xpos_base,
-        bool                  xpos_down,
         bool                  inplace) {
     GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
 
@@ -6280,15 +6277,13 @@ static struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+    int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
     memcpy(params +  5, &freq_base,    sizeof(float));
     memcpy(params +  6, &freq_scale,   sizeof(float));
     memcpy(params +  7, &ext_factor,   sizeof(float));
     memcpy(params +  8, &attn_factor,  sizeof(float));
     memcpy(params +  9, &beta_fast,    sizeof(float));
     memcpy(params + 10, &beta_slow,    sizeof(float));
-    memcpy(params + 11, &xpos_base,    sizeof(float));
-    memcpy(params + 12, &xpos_down,    sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE;
@@ -6305,10 +6300,9 @@ struct ggml_tensor * ggml_rope(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   n_dims,
-        int                   mode,
-        int                   n_ctx) {
+        int                   mode) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
     );
 }
 
@@ -6317,10 +6311,9 @@ struct ggml_tensor * ggml_rope_inplace(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   n_dims,
-        int                   mode,
-        int                   n_ctx) {
+        int                   mode) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
     );
 }
 
@@ -6331,8 +6324,7 @@ struct ggml_tensor * ggml_rope_ext(
         struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
@@ -6340,8 +6332,8 @@ struct ggml_tensor * ggml_rope_ext(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
-        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, false
     );
 }
 
@@ -6352,8 +6344,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
         struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
@@ -6361,8 +6352,8 @@ struct ggml_tensor * ggml_rope_ext_inplace(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
-        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, true
     );
 }
 
@@ -6372,8 +6363,7 @@ struct ggml_tensor * ggml_rope_custom(
         struct ggml_tensor  * b,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
@@ -6381,8 +6371,8 @@ struct ggml_tensor * ggml_rope_custom(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
-        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, false
     );
 }
 
@@ -6392,8 +6382,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         struct ggml_tensor  * b,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
@@ -6401,21 +6390,11 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
-        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, true
     );
 }
 
-struct ggml_tensor * ggml_rope_xpos_inplace(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int                   n_dims,
-        float                 base,
-        bool                  down) {
-    return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
-}
-
 // ggml_rope_back
 
 struct ggml_tensor * ggml_rope_back(
@@ -6425,16 +6404,13 @@ struct ggml_tensor * ggml_rope_back(
         struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
         float                 attn_factor,
         float                 beta_fast,
-        float                 beta_slow,
-        float                 xpos_base,
-        bool                  xpos_down) {
+        float                 beta_slow) {
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
@@ -6450,15 +6426,13 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+    int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
     memcpy(params +  5, &freq_base,    sizeof(float));
     memcpy(params +  6, &freq_scale,   sizeof(float));
     memcpy(params +  7, &ext_factor,   sizeof(float));
     memcpy(params +  8, &attn_factor,  sizeof(float));
     memcpy(params +  9, &beta_fast,    sizeof(float));
     memcpy(params + 10, &beta_slow,    sizeof(float));
-    memcpy(params + 11, &xpos_base,    sizeof(float));
-    memcpy(params + 12, &xpos_down,    sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE_BACK;
@@ -14227,8 +14201,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
+    float * cos_theta, float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -14245,18 +14218,19 @@ static void rope_yarn(
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
-    return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
 }
 
 static void ggml_rope_cache_init(
-     float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
-     float * cache, float sin_sign, float theta_scale
-) {
+     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+     float * cache, float sin_sign, float theta_scale) {
+    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
     float theta = theta_base;
     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
         rope_yarn(
-            theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
         );
         cache[i0 + 1] *= sin_sign;
 
@@ -14265,11 +14239,11 @@ static void ggml_rope_cache_init(
 }
 
 GGML_CALL void ggml_rope_yarn_corr_dims(
-    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
 ) {
     // start and end correction dims
-    float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base));
-    float end   =  ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base));
+    float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
+    float end   =  ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
     dims[0] = MAX(0, start);
     dims[1] = MIN(n_dims - 1, end);
 }
@@ -14289,15 +14263,11 @@ static void ggml_compute_forward_rope_f32(
 
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
-    // these two only relevant for xPos RoPE:
-    float xpos_base;
-    bool  xpos_down;
-
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
-    const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
@@ -14305,8 +14275,6 @@ static void ggml_compute_forward_rope_f32(
     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&xpos_base,   (int32_t *) dst->op_params + 11, sizeof(float));
-    memcpy(&xpos_down,   (int32_t *) dst->op_params + 12, sizeof(bool));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -14336,20 +14304,15 @@ static void ggml_compute_forward_rope_f32(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
 
     const float * freq_factors = NULL;
-    if (is_neox) {
-        if (src2 != NULL) {
-            GGML_ASSERT(src2->type == GGML_TYPE_F32);
-            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
     }
 
     // backward process uses inverse rotation by cos and sin.
@@ -14364,94 +14327,50 @@ static void ggml_compute_forward_rope_f32(
             const int64_t p = pos[i2];
 
             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
-                ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
+            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta_base = (float)p;
-
-                if (is_glm) {
-                    theta_base = MIN(p, n_ctx - 2);
-                    float block_theta = MAX(p - (n_ctx - 2), 0);
-                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                        const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base) * sin_sign;
-                        const float cos_block_theta = cosf(block_theta);
-                        const float sin_block_theta = sinf(block_theta) * sin_sign;
-
-                        theta_base  *= theta_scale;
-                        block_theta *= theta_scale;
-
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float x0 = src[0];
-                        const float x1 = src[n_dims/2];
-                        const float x2 = src[n_dims];
-                        const float x3 = src[n_dims/2*3];
-
-                        dst_data[0]          = x0*cos_theta - x1*sin_theta;
-                        dst_data[n_dims/2]   = x0*sin_theta + x1*cos_theta;
-                        dst_data[n_dims]     = x2*cos_block_theta - x3*sin_block_theta;
-                        dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
-                    }
-                } else if (!is_neox) {
-                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                         const float cos_theta = cache[i0 + 0];
                         const float sin_theta = cache[i0 + 1];
 
-                        // zeta scaling for xPos only:
-                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
-                        if (xpos_down) zeta = 1.0f / zeta;
-
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float x0 = src[0];
                         const float x1 = src[1];
 
-                        dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
-                        dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
+                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[1] = x0*sin_theta + x1*cos_theta;
                     }
                 } else {
-                    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
-                    for (int64_t ic = 0; ic < ne0; ic += 2) {
-                        if (ic < n_dims) {
-                            const int64_t i0 = ic/2;
-
-                            const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
-
-                            float cos_theta, sin_theta;
-                            rope_yarn(
-                                theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
-                                &cos_theta, &sin_theta
-                            );
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const int64_t ic = i0/2;
 
-                            sin_theta  *= sin_sign;
-                            theta_base *= theta_scale;
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
 
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
 
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims/2];
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];
 
-                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-                        } else {
-                            const int64_t i0 = ic;
+                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                    }
+                }
 
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                    const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                    float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-                            dst_data[0] = src[0];
-                            dst_data[1] = src[1];
-                        }
-                    }
+                    dst_data[0] = src[0];
+                    dst_data[1] = src[1];
                 }
             }
         }
@@ -14477,8 +14396,8 @@ static void ggml_compute_forward_rope_f16(
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
-    const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
     memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
@@ -14514,20 +14433,15 @@ static void ggml_compute_forward_rope_f16(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
 
     const float * freq_factors = NULL;
-    if (is_neox) {
-        if (src2 != NULL) {
-            GGML_ASSERT(src2->type == GGML_TYPE_F32);
-            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
     }
 
     // backward process uses inverse rotation by cos and sin.
@@ -14542,43 +14456,14 @@ static void ggml_compute_forward_rope_f16(
             const int64_t p = pos[i2];
 
             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
-                ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
+            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta_base = (float)p;
-
-                if (is_glm) {
-                    theta_base = MIN(p, n_ctx - 2);
-                    float block_theta = MAX(p - (n_ctx - 2), 0);
-                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                        const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base) * sin_sign;
-                        const float cos_block_theta = cosf(block_theta);
-                        const float sin_block_theta = sinf(block_theta) * sin_sign;
-
-                        theta_base  *= theta_scale;
-                        block_theta *= theta_scale;
-
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float x0 = GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
-                        const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
-                        const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
-
-                        dst_data[0]          = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims/2]   = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        dst_data[n_dims]     = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
-                        dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
-                    }
-                } else if (!is_neox) {
-                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                         const float cos_theta = cache[i0 + 0];
                         const float sin_theta = cache[i0 + 1];
 
@@ -14592,40 +14477,29 @@ static void ggml_compute_forward_rope_f16(
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
                 } else {
-                    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
-                    for (int64_t ic = 0; ic < ne0; ic += 2) {
-                        if (ic < n_dims) {
-                            const int64_t i0 = ic/2;
-
-                            const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const int64_t ic = i0/2;
 
-                            float cos_theta, sin_theta;
-                            rope_yarn(
-                                theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
-                                &cos_theta, &sin_theta
-                            );
-
-                            sin_theta  *= sin_sign;
-                            theta_base *= theta_scale;
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
 
-                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
 
-                            const float x0 = GGML_FP16_TO_FP32(src[0]);
-                            const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
 
-                            dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                            dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        } else {
-                            const int64_t i0 = ic;
+                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
+                }
 
-                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                    ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-                            dst_data[0] = src[0];
-                            dst_data[1] = src[1];
-                        }
-                    }
+                    dst_data[0] = src[0];
+                    dst_data[1] = src[1];
                 }
             }
         }
@@ -18327,9 +18201,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     //const int n_past = ((int32_t *) tensor->op_params)[0];
                     const int n_dims     = ((int32_t *) tensor->op_params)[1];
                     const int mode       = ((int32_t *) tensor->op_params)[2];
-                    const int n_ctx      = ((int32_t *) tensor->op_params)[3];
-                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
-                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+                    //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
+                    const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
+                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
                     memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
                     memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
@@ -18337,8 +18211,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
                     memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
                     memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
-                    memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
-                    memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));
 
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
@@ -18348,16 +18220,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src2,
                                 n_dims,
                                 mode,
-                                n_ctx,
-                                n_orig_ctx,
+                                n_ctx_orig,
                                 freq_base,
                                 freq_scale,
                                 ext_factor,
                                 attn_factor,
                                 beta_fast,
-                                beta_slow,
-                                xpos_base,
-                                xpos_down),
+                                beta_slow),
                             zero_table);
                 }
             } break;
@@ -18367,9 +18236,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     //const int n_past = ((int32_t *) tensor->op_params)[0];
                     const int n_dims     = ((int32_t *) tensor->op_params)[1];
                     const int mode       = ((int32_t *) tensor->op_params)[2];
-                    const int n_ctx      = ((int32_t *) tensor->op_params)[3];
-                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
-                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+                    //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
+                    const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
+                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
                     memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
                     memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
@@ -18377,8 +18246,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
                     memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
                     memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
-                    memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
-                    memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));
 
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
@@ -18388,16 +18255,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src2,
                                 n_dims,
                                 mode,
-                                n_ctx,
-                                n_orig_ctx,
+                                n_ctx_orig,
                                 freq_base,
                                 freq_scale,
                                 ext_factor,
                                 attn_factor,
                                 beta_fast,
                                 beta_slow,
-                                xpos_base,
-                                xpos_down,
                                 false),
                             zero_table);
                 }
diff --git a/ggml.h b/ggml.h
index addcf1bfef6739d4160ce9d327a710c70261f578..13502a3622fc4f7e3725b0b88a6ccd88b88b3f04 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1465,7 +1465,6 @@ extern "C" {
     // rotary position embedding
     // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
     // if mode & 2 == 1, GPT-NeoX style
-    // if mode & 4 == 1, ChatGLM style
     //
     // b is an int32 vector with size a->ne[2], it contains the positions
     // c is freq factors (e.g. phi3-128k), (optional)
@@ -1474,8 +1473,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             int                   n_dims,
-            int                   mode,
-            int                   n_ctx);
+            int                   mode);
 
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_inplace(
@@ -1483,8 +1481,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             int                   n_dims,
-            int                   mode,
-            int                   n_ctx);
+            int                   mode);
 
     // custom RoPE
     GGML_API struct ggml_tensor * ggml_rope_ext(
@@ -1494,8 +1491,7 @@ extern "C" {
             struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1511,8 +1507,7 @@ extern "C" {
             struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1526,8 +1521,7 @@ extern "C" {
             struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1542,8 +1536,7 @@ extern "C" {
             struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1552,17 +1545,9 @@ extern "C" {
             float                 beta_slow),
         "use ggml_rope_ext_inplace instead");
 
-    struct ggml_tensor * ggml_rope_xpos_inplace(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int                   n_dims,
-        float                 base,
-        bool                  down);
-
     // compute correction dims for YaRN RoPE scaling
     GGML_CALL void ggml_rope_yarn_corr_dims(
-        int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+        int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
@@ -1573,16 +1558,13 @@ extern "C" {
             struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
             float                 attn_factor,
             float                 beta_fast,
-            float                 beta_slow,
-            float                 xpos_base,
-            bool                  xpos_down);
+            float                 beta_slow);
 
     // clamp
     // in-place, returns view(a)
index b446225849d5f9bf55c083226e4fd9e1a8d2a8f1..1a4058b3f1f1076d3cbd56b2322213c80eabb5d2 100644 (file)
@@ -14,7 +14,7 @@ void main() {
     const bool is_neox = (pcs.mode & 2) != 0;
 
     float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
 
     const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
 
index 2c0235d75b6b63dde7e08d8887c52b0dda0b4b72..65e03827a2660efc187d6ac7fa126741fa90dc44 100644 (file)
@@ -14,7 +14,7 @@ void main() {
     const bool is_neox = (pcs.mode & 2) != 0;
 
     float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
 
     const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
 
index 57ba6597a7eb20883cc0faf7910b889ba7552089..7b9394cb2fffc5e464fc8e7e82f8dbe7f7eac979 100644 (file)
@@ -9,7 +9,7 @@ layout (push_constant) uniform parameter {
     uint outOff;
     int n_dims;
     int mode;
-    int n_orig_ctx;
+    int n_ctx_orig;
     float freq_base;
     float freq_scale;
     float ext_factor;
@@ -54,14 +54,14 @@ void rope_yarn(
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
-    return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base));
+float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base));
 }
 
 void rope_yarn_corr_dims(
-    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
 ) {
     // start and end correction dims
-    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
 }
index 06889126ecdc4b2ac0d1e2aab202c9c5849c2a39..414d390e891a1b89ae77171547fa1f8260c9444b 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1848,7 +1848,7 @@ struct llama_hparams {
     float    rope_attn_factor = 1.0f;
     float    rope_freq_base_train;
     float    rope_freq_scale_train;
-    uint32_t n_yarn_orig_ctx;
+    uint32_t n_ctx_orig_yarn;
     float    rope_yarn_log_mul;
 
     // for State Space Models
@@ -1890,7 +1890,7 @@ struct llama_hparams {
         if (this->n_expert_shared    != other.n_expert_shared)    return true;
 
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
-        if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
+        if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
 
         if (this->ssm_d_conv  != other.ssm_d_conv)  return true;
         if (this->ssm_d_inner != other.ssm_d_inner) return true;
@@ -1949,7 +1949,7 @@ struct llama_cparams {
     float rope_freq_base;
     float rope_freq_scale;
 
-    uint32_t n_yarn_orig_ctx;
+    uint32_t n_ctx_orig_yarn;
     // These hyperparameters are not exposed in GGUF, because all
     // existing YaRN models use the same values for them.
     float yarn_ext_factor;
@@ -4005,8 +4005,8 @@ static void llm_load_hparams(
     ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
     hparams.rope_finetuned = rope_finetuned;
 
-    hparams.n_yarn_orig_ctx = hparams.n_ctx_train;
-    ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_yarn_orig_ctx, false);
+    hparams.n_ctx_orig_yarn = hparams.n_ctx_train;
+    ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn, false);
 
     // rope_freq_base (optional)
     hparams.rope_freq_base_train = 10000.0f;
@@ -4968,7 +4968,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
-    LLAMA_LOG_INFO("%s: n_yarn_orig_ctx  = %u\n",     __func__, hparams.n_yarn_orig_ctx);
+    LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
     LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
     LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
     LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
@@ -7134,7 +7134,7 @@ struct llm_build_context {
     const int32_t n_kv;     // size of KV cache to consider (n_kv <= kv_self.size)
     const int32_t n_outputs;
     const int32_t kv_head;  // index of where we store new KV data in the cache
-    const int32_t n_orig_ctx;
+    const int32_t n_ctx_orig;
 
     const bool flash_attn;
 
@@ -7183,7 +7183,7 @@ struct llm_build_context {
         n_kv             (worst_case ? kv_self.size : kv_self.n),
         n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
         kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
-        n_orig_ctx       (cparams.n_yarn_orig_ctx),
+        n_ctx_orig       (cparams.n_ctx_orig_yarn),
         flash_attn       (cparams.flash_attn),
         pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),
@@ -7241,7 +7241,7 @@ struct llm_build_context {
                             ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                             ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                             0),
-                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
 
             cb(tmp, "K_shifted", il);
@@ -7350,7 +7350,7 @@ struct llm_build_context {
         // choose long/short freq factors based on the context size
         const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
 
-        if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) {
+        if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
             return model.layers[il].rope_long;
         }
 
@@ -7466,14 +7466,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -7597,12 +7597,12 @@ struct llm_build_context {
                     case MODEL_7B:
                         Qcur = ggml_rope_ext(
                             ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                            n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         Kcur = ggml_rope_ext(
                             ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                            n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         break;
@@ -7709,14 +7709,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -7829,13 +7829,13 @@ struct llm_build_context {
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -7953,14 +7953,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -8106,14 +8106,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -8460,14 +8460,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -8900,14 +8900,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -9019,13 +9019,13 @@ struct llm_build_context {
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -9131,14 +9131,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -9245,14 +9245,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -9397,7 +9397,7 @@ struct llm_build_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
@@ -9408,7 +9408,7 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -9519,7 +9519,7 @@ struct llm_build_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
@@ -9528,7 +9528,7 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -9636,13 +9636,13 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head,    n_tokens), inp_pos, nullptr,
-                        n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
@@ -9844,14 +9844,14 @@ struct llm_build_context {
 
                 struct ggml_tensor * Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 struct ggml_tensor * Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -9960,14 +9960,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -10077,14 +10077,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -10207,14 +10207,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -10327,7 +10327,7 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -10336,7 +10336,7 @@ struct llm_build_context {
 
                 Kcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
@@ -10447,14 +10447,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -10737,14 +10737,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -10868,14 +10868,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -10982,14 +10982,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -11117,14 +11117,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -11334,7 +11334,7 @@ struct llm_build_context {
                 q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
                 q_pe = ggml_rope_ext(
                     ctx0, q_pe, inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor_scaled, beta_fast, beta_slow
                 );
                 cb(q_pe, "q_pe", il);
@@ -11343,7 +11343,7 @@ struct llm_build_context {
                 k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
                 k_pe = ggml_rope_ext(
                     ctx0, k_pe, inp_pos, nullptr,
-                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor_scaled, beta_fast, beta_slow
                 );
                 cb(k_pe, "k_pe", il);
@@ -16067,8 +16067,8 @@ struct llama_context * llama_new_context_with_model(
 
     cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
-    cparams.n_yarn_orig_ctx  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
-                               hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
+    cparams.n_ctx_orig_yarn  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
+                               hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
                                                               hparams.n_ctx_train;
 
     cparams.cb_eval           = params.cb_eval;
index 8dc90a45dd8726b5ec672ce4ca167f2e4d4c5201..ce406a8af867a65875fce92bc5fa09d40a3f3da5 100644 (file)
@@ -1141,7 +1141,7 @@ struct test_rope : public test_case {
     const std::array<int64_t, 4> ne_a;
     int n_dims;
     int mode;
-    int n_ctx;
+    int n_ctx; // used to generate positions
     float fs; // freq_scale
     float ef; // ext_factor
     float af; // attn_factor
@@ -1168,7 +1168,7 @@ struct test_rope : public test_case {
         }
         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
         ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
-        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
         return out;
     }
 
@@ -1615,7 +1615,7 @@ struct llama_hparams {
 
     // cparams
     static constexpr uint32_t n_ctx = 512; // user-specified context size
-    static constexpr uint32_t n_orig_ctx = n_ctx;
+    static constexpr uint32_t n_ctx_orig = n_ctx;
 
     // batch
     int32_t n_tokens;
@@ -1806,13 +1806,13 @@ struct test_llama : public test_llm {
 
                 Qcur = ggml_rope_ext(
                     ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens), inp_pos, nullptr,
-                    hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
                 Kcur = ggml_rope_ext(
                     ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
-                    hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
@@ -1931,12 +1931,12 @@ struct test_falcon : public test_llm {
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_ext(
-                    ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
+                    ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
                 Kcur = ggml_rope_ext(
-                    ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
+                    ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
@@ -2236,15 +2236,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                 for (float ef : { 0.0f, 0.7465f }) {
                     for (float af : { 1.0f, 1.4245f }) {
                         for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                            // TODO: ff not supported yet for !neox
-                            test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B
-                            if (all) {
-                                test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B
-                                test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B
-                                test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B
-                            }
-
                             for (bool ff : {false, true}) { // freq_factors
+                                test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
+
+                                if (all) {
+                                    test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
+                                    test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
+                                    test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
+                                }
+
                                 if (all) {
                                     test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
                                     test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
@@ -2256,6 +2256,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                                 test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
                             }
                         }
+
                         all = false;
                     }
                 }
index 21ca43be3a963bdc5dba3fdba4074d5597f3d96a..a353276459b2d144abd609b318d7e60558fbf4c7 100644 (file)
@@ -1465,7 +1465,7 @@ int main(int argc, const char ** argv) {
                             continue;
                         }
 
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
 
                         GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
                         check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
@@ -1505,7 +1505,7 @@ int main(int argc, const char ** argv) {
                             continue;
                         }
 
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
 
                         GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
                         check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
index 26c1f42dc0e956c88adf99fee795c6fb4b465e5e..f0895ffaad6a10fb6c92b9de408fc80c3e14ba36 100644 (file)
@@ -162,12 +162,12 @@ int main(int /*argc*/, const char ** /*argv*/) {
         x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
 
         // 100, 101, 102, ..., 172
-        struct ggml_tensor * r0 = ggml_rope(ctx0, x,  p0, n_rot, mode, 1024);
+        struct ggml_tensor * r0 = ggml_rope(ctx0, x,  p0, n_rot, mode);
         // -67, -67, -67, ..., -67
-        struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode, 1024); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
+        struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
 
         //  33,  34,  35, ..., 105
-        struct ggml_tensor * r2 = ggml_rope(ctx0, x,  p2, n_rot, mode, 1024);
+        struct ggml_tensor * r2 = ggml_rope(ctx0, x,  p2, n_rot, mode);
 
         ggml_cgraph * gf = ggml_new_graph(ctx0);