]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : refactor rope norm/neox (llama/7634)
authorGeorgi Gerganov <redacted>
Wed, 5 Jun 2024 08:29:20 +0000 (11:29 +0300)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
* ggml : unify rope norm/neox (CPU)

* ggml : fix compile warning

* ggml : remove GLM rope mode

ggml-ci

* metal : better rope implementation

ggml-ci

* cuda : better rope implementation

ggml-ci

* naming : n_orig_ctx -> n_ctx_orig

ggml-ci

* dev : add reminders to update backends

ggml-ci

* vulkan : fix ggml_rope_ext() usage

* cuda : fix array size + indents

ggml-ci

include/ggml/ggml.h
src/ggml-cuda/rope.cu
src/ggml-kompute.cpp
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-sycl.cpp
src/ggml-vulkan.cpp
src/ggml.c
tests/test-backend-ops.cpp
tests/test-grad0.cpp

index addcf1bfef6739d4160ce9d327a710c70261f578..13502a3622fc4f7e3725b0b88a6ccd88b88b3f04 100644 (file)
@@ -1465,7 +1465,6 @@ extern "C" {
     // rotary position embedding
     // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
     // if mode & 2 == 1, GPT-NeoX style
-    // if mode & 4 == 1, ChatGLM style
     //
     // b is an int32 vector with size a->ne[2], it contains the positions
     // c is freq factors (e.g. phi3-128k), (optional)
@@ -1474,8 +1473,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             int                   n_dims,
-            int                   mode,
-            int                   n_ctx);
+            int                   mode);
 
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_inplace(
@@ -1483,8 +1481,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             int                   n_dims,
-            int                   mode,
-            int                   n_ctx);
+            int                   mode);
 
     // custom RoPE
     GGML_API struct ggml_tensor * ggml_rope_ext(
@@ -1494,8 +1491,7 @@ extern "C" {
             struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1511,8 +1507,7 @@ extern "C" {
             struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1526,8 +1521,7 @@ extern "C" {
             struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1542,8 +1536,7 @@ extern "C" {
             struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1552,17 +1545,9 @@ extern "C" {
             float                 beta_slow),
         "use ggml_rope_ext_inplace instead");
 
-    struct ggml_tensor * ggml_rope_xpos_inplace(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int                   n_dims,
-        float                 base,
-        bool                  down);
-
     // compute correction dims for YaRN RoPE scaling
     GGML_CALL void ggml_rope_yarn_corr_dims(
-        int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+        int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
@@ -1573,16 +1558,13 @@ extern "C" {
             struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
             float                 attn_factor,
             float                 beta_fast,
-            float                 beta_slow,
-            float                 xpos_base,
-            bool                  xpos_down);
+            float                 beta_slow);
 
     // clamp
     // in-place, returns view(a)
index 0dd07977ebab1e263db0419c67da9e7ae0c79cd3..596fb7c135058e1c66896ffb7944f1e7f3783d82 100644 (file)
@@ -1,7 +1,7 @@
 #include "rope.cuh"
 
 struct rope_corr_dims {
-    float v[4];
+    float v[2];
 };
 
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static __device__ void rope_yarn(
     float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
+    float * cos_theta, float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -29,27 +28,38 @@ static __device__ void rope_yarn(
     *sin_theta = sinf(theta) * mscale;
 }
 
-// rope == RoPE == rotary positional embedding
-template<typename T, bool has_pos>
-static __global__ void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
-) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+template<typename T, bool has_ff>
+static __global__ void rope_norm(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i = row*ncols + col;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ne0 + i0;
     const int i2 = row/p_delta_rows;
 
-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, -float(col)/ncols);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + 1];
@@ -58,23 +68,20 @@ static __global__ void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_pos, bool has_freq_facs>
+template<typename T, bool has_ff>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
-) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;
 
-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -82,16 +89,17 @@ static __global__ void rope_neox(
         return;
     }
 
-    const int i  = row*ncols + ib*n_dims + ic/2;
+    const int i  = row*ne0 + i0/2;
     const int i2 = row/p_delta_rows;
 
-    const int p = has_pos ? pos[i2] : 0;
-    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
 
-    const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];
@@ -100,144 +108,81 @@ static __global__ void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
-     // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
-    const float sin_theta = sinf(theta);
-    const float cos_theta = cosf(theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
-    const float sin_block_theta = sinf(block_theta);
-    const float cos_block_theta = cosf(block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
-
 template<typename T>
-static void rope_cuda(
-    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 2 == 0);
+static void rope_norm_cuda(
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
-    if (pos == nullptr) {
-        rope<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
     } else {
-        rope<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
     }
 }
 
 template<typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 2 == 0);
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
-    if (pos == nullptr) {
-        if (freq_factors == nullptr) {
-            rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
-        } else {
-            rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+    if (freq_factors == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                 theta_scale, freq_factors
                 );
-        }
     } else {
-        if (freq_factors == nullptr) {
-            rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
-        } else {
-            rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                 theta_scale, freq_factors
                 );
-        }
     }
 }
 
-static void rope_glm_f32_cuda(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, int n_ctx, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
-    const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
-    const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
-}
-
-static void rope_cuda_f16(
-    const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f16(
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_cuda<half>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
-static void rope_cuda_f32(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f32(
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_cuda<float>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f16(
-    const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f32(
-    const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
 ) {
 
-    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -258,16 +203,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nr = ggml_nrows(src0);
 
-    //const int n_past      = ((int32_t *) dst->op_params)[0];
-    const int n_dims      = ((int32_t *) dst->op_params)[1];
-    const int mode        = ((int32_t *) dst->op_params)[2];
-    const int n_ctx       = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx  = ((int32_t *) dst->op_params)[4];
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
     memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
@@ -275,38 +226,28 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
-    const float * freq_factors = nullptr;
-    const int32_t * pos = nullptr;
-
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
 
-    pos = (const int32_t *) src1_d;
+    const int32_t * pos = (const int32_t *) src1_d;
 
-    if (is_neox) {
-        if (src2 != nullptr) {
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+    const float * freq_factors = nullptr;
+    if (src2 != nullptr) {
+        freq_factors = (const float *) src2->data;
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, stream
             );
         } else {
@@ -314,14 +255,14 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
-            rope_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else {
             GGML_ASSERT(false);
index eabd70d5eeed86eb2d465bee672eabf015a54c0d..5592741be425580f351ffa151a6e1e2161c95669 100644 (file)
@@ -1192,7 +1192,7 @@ static void ggml_vk_rope(
     const std::shared_ptr<kp::Tensor>& inB,
     const std::shared_ptr<kp::Tensor>& out,
     uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
+    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
     float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
     int32_t ne01, int32_t ne02, int32_t ne03,
     uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
@@ -1221,14 +1221,14 @@ static void ggml_vk_rope(
 
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
-        int32_t n_dims, mode, n_orig_ctx;
+        int32_t n_dims, mode, n_ctx_orig;
         float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
         uint32_t nb00, nb01, nb02, nb03;
         int32_t ne0;
         uint32_t nb0, nb1, nb2, nb3;
     } pushConsts {
         safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
-        n_dims, mode, n_orig_ctx,
+        n_dims, mode, n_ctx_orig,
         freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
         nb00, nb01, nb02, nb03,
         ne0,
@@ -1692,13 +1692,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 #pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
                         GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
 
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
+
                         GGML_ASSERT(ne10 == ne02);
                         GGML_ASSERT(src0t == dstt);
                         // const int n_past = ((int32_t *) dst->op_params)[0];
                         const int n_dims     = ((int32_t *) dst->op_params)[1];
                         const int mode       = ((int32_t *) dst->op_params)[2];
                         // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
-                        const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+                        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
                         float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                         memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
@@ -1708,7 +1711,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                         memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
                         memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
                         ggml_vk_rope(
-                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
+                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
                             freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
                             ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
                         );
index fddc44f78d8afc5198e795625ff768a80e784821..946f11813dcf994c3790dd31b3e47378556949cc 100644 (file)
@@ -172,8 +172,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                      rope_f32,                       true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                      rope_f16,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                 rope_norm_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                 rope_norm_f16,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                 rope_neox_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                 rope_neox_f16,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
@@ -2285,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         const int n_dims     = ((int32_t *) dst->op_params)[1];
                         const int mode       = ((int32_t *) dst->op_params)[2];
                         // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
-                        const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+                        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
                         float freq_base;
                         float freq_scale;
@@ -2302,22 +2306,23 @@ static enum ggml_status ggml_metal_graph_compute(
                         memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
                         const bool is_neox = mode & 2;
-                        const bool is_glm  = mode & 4;
 
-                        GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
+                        id<MTLComputePipelineState> pipeline = nil;
 
                         if (!is_neox) {
-                            GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
+                            switch (src0->type) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                                default: GGML_ASSERT(false);
+                            };
+                        } else {
+                            switch (src0->type) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                                default: GGML_ASSERT(false);
+                            };
                         }
 
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        switch (src0->type) {
-                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
-                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
-                            default: GGML_ASSERT(false);
-                        };
-
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
                         [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
@@ -2345,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:19];
                         [encoder setBytes:&n_past      length:sizeof(     int) atIndex:20];
                         [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:21];
-                        [encoder setBytes:&mode        length:sizeof(     int) atIndex:22];
-                        [encoder setBytes:&n_orig_ctx  length:sizeof(     int) atIndex:23];
-                        [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:24];
-                        [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:25];
-                        [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:26];
-                        [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:27];
-                        [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:28];
-                        [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:29];
+                        [encoder setBytes:&n_ctx_orig  length:sizeof(     int) atIndex:22];
+                        [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23];
+                        [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24];
+                        [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25];
+                        [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:26];
+                        [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:27];
+                        [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:28];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
index 0cb85e1a5bad4daeabe6e1de00656402556e9f8b..e2796fd601281166213e81575fe56b54c97f2c72 100644 (file)
@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    thread float * cos_theta, thread float * sin_theta
-) {
+    thread float * cos_theta, thread float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -1672,19 +1671,20 @@ static void rope_yarn(
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
-    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
 }
 
 static void rope_yarn_corr_dims(
-    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
 ) {
     // start and end correction dims
-    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
 }
 
-typedef void (rope_t)(
+template<typename T>
+kernel void kernel_rope_norm(
         device const    void * src0,
         device const int32_t * src1,
         device const   float * src2,
@@ -1707,8 +1707,7 @@ typedef void (rope_t)(
         constant    uint64_t & nb3,
         constant         int & n_past,
         constant         int & n_dims,
-        constant         int & mode,
-        constant         int & n_orig_ctx,
+        constant         int & n_ctx_orig,
         constant       float & freq_base,
         constant       float & freq_scale,
         constant       float & ext_factor,
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
         constant       float & beta_slow,
         uint  tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
-        uint3 tgpig[[threadgroup_position_in_grid]]);
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
 
 template<typename T>
-kernel void kernel_rope(
+kernel void kernel_rope_neox(
         device const    void * src0,
         device const int32_t * src1,
         device const   float * src2,
@@ -1743,8 +1784,7 @@ kernel void kernel_rope(
         constant    uint64_t & nb3,
         constant         int & n_past,
         constant         int & n_dims,
-        constant         int & mode,
-        constant         int & n_orig_ctx,
+        constant         int & n_ctx_orig,
         constant       float & freq_base,
         constant       float & freq_scale,
         constant       float & ext_factor,
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
     const int64_t i2 = tgpig[1];
     const int64_t i1 = tgpig[0];
 
-    const bool is_neox = mode & 2;
-
     float corr_dims[2];
-    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     device const int32_t * pos = src1;
 
-    const int64_t p = pos[i2];
-
-    const float theta_base = (float)p;
+    const float theta_base = (float) pos[i2];
     const float inv_ndims = -1.f/n_dims;
 
-    if (!is_neox) {
-        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
-
-            float cos_theta, sin_theta;
-            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-            const T x0 = src[0];
-            const T x1 = src[1];
+    float cos_theta;
+    float sin_theta;
 
-            dst_data[0] = x0*cos_theta - x1*sin_theta;
-            dst_data[1] = x0*sin_theta + x1*cos_theta;
-        }
-    } else {
-        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
-            if (ic < n_dims) {
-                const int64_t i0 = ic/2;
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
 
-                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
-
-                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
 
-                float cos_theta, sin_theta;
-                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
 
-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-                const float x0 = src[0];
-                const float x1 = src[n_dims/2];
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
 
-                dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-            } else {
-                const int64_t i0 = ic;
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
 
-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-                dst_data[0] = src[0];
-                dst_data[1] = src[1];
-            }
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
         }
     }
 }
 
-template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
-template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
 
 typedef void (im2col_t)(
         device const float * x,
index 5cd97e4ff98df55983c43be14025dea07c609ea8..3ff76474d7e72e7373e564b7de830169bd92a77e 100644 (file)
@@ -8928,49 +8928,6 @@ static void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-, const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
-     // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
-    const float sin_theta = sycl::sin((float)theta);
-    const float cos_theta = sycl::cos((float)theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta =
-        ((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
-    const float sin_block_theta = sycl::sin((float)block_theta);
-    const float cos_block_theta = sycl::cos((float)block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
     }
 }
 
-static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
-                              const int32_t *pos, float freq_scale,
-                              int p_delta_rows, float freq_base, int n_ctx,
-                              dpct::queue_ptr stream) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
-    const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, nrows, num_blocks_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             rope_glm_f32(x, dst, ncols, pos, freq_scale,
-                                          p_delta_rows, freq_base, n_ctx,
-                                          item_ct1);
-                         });
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, dpct::queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -14066,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     //const int n_past      = ((int32_t *) dst->op_params)[0];
     const int n_dims      = ((int32_t *) dst->op_params)[1];
     const int mode        = ((int32_t *) dst->op_params)[2];
-    const int n_ctx       = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx  = ((int32_t *) dst->op_params)[4];
+    //const int n_ctx       = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig  = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -14087,7 +14028,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
+
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
 
     if (is_neox) {
         pos = (const int32_t *) src1_dd;
@@ -14100,13 +14043,10 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
index 5e12ea9dde4d7b53d217451b2c3e040f686625a9..e0c512c0dab0f5bc9ab504618945b54a117cb81b 100644 (file)
@@ -3898,11 +3898,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         {
             const int mode = ((const int32_t *) dst->op_params)[2];
             const bool is_neox = mode & 2;
-            const bool is_glm  = mode & 4;
-
-            if (is_glm) {
-                return nullptr;
-            }
 
             if (is_neox) {
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -4401,7 +4396,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
     const int n_dims        = ((int32_t *) dst->op_params)[1];
     const int mode          = ((int32_t *) dst->op_params)[2];
     // const int n_ctx         = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx    = ((int32_t *) dst->op_params)[4];
+    const int n_ctx_orig    = ((int32_t *) dst->op_params)[4];
     const float freq_base   = ((float *)   dst->op_params)[5];
     const float freq_scale  = ((float *)   dst->op_params)[6];
     const float ext_factor  = ((float *)   dst->op_params)[7];
@@ -4410,12 +4405,12 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
     const float beta_slow   = ((float *)   dst->op_params)[10];
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
 
-    GGML_ASSERT(!is_glm);
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
 
     float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     if (is_neox) {
         const float theta_scale = powf(freq_base, -2.0f/n_dims);
@@ -6485,9 +6480,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_ROPE:
             {
                 const int mode = ((const int32_t *) op->op_params)[2];
-                const bool is_glm  = mode & 4;
 
-                return !is_glm;
+                return true;
             } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
@@ -6992,15 +6986,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
     } else if (tensor->op == GGML_OP_ROPE) {
         const int n_dims      = ((int32_t *) tensor->op_params)[1];
         const int mode        = ((int32_t *) tensor->op_params)[2];
-        const int n_ggml_ctx       = ((int32_t *) tensor->op_params)[3];
-        const int n_orig_ggml_ctx  = ((int32_t *) tensor->op_params)[4];
+        //const int n_ctx_ggml       = ((int32_t *) tensor->op_params)[3];
+        const int n_ctx_orig_ggml  = ((int32_t *) tensor->op_params)[4];
         float freq_base       = ((float *)   tensor->op_params)[5];
         float freq_scale      = ((float *)   tensor->op_params)[6];
         float ext_factor      = ((float *)   tensor->op_params)[7];
         float attn_factor     = ((float *)   tensor->op_params)[8];
         float beta_fast       = ((float *)   tensor->op_params)[9];
         float beta_slow       = ((float *)   tensor->op_params)[10];
-        tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ggml_ctx, n_orig_ggml_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+        tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
index 11e5c34ab56adef2c6857ca927ea19e044fa0984..1fc77743bc7b945baa6441845735a73870949ada 100644 (file)
@@ -6250,16 +6250,13 @@ static struct ggml_tensor * ggml_rope_impl(
         struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
         float                 attn_factor,
         float                 beta_fast,
         float                 beta_slow,
-        float                 xpos_base,
-        bool                  xpos_down,
         bool                  inplace) {
     GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
 
@@ -6280,15 +6277,13 @@ static struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+    int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
     memcpy(params +  5, &freq_base,    sizeof(float));
     memcpy(params +  6, &freq_scale,   sizeof(float));
     memcpy(params +  7, &ext_factor,   sizeof(float));
     memcpy(params +  8, &attn_factor,  sizeof(float));
     memcpy(params +  9, &beta_fast,    sizeof(float));
     memcpy(params + 10, &beta_slow,    sizeof(float));
-    memcpy(params + 11, &xpos_base,    sizeof(float));
-    memcpy(params + 12, &xpos_down,    sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE;
@@ -6305,10 +6300,9 @@ struct ggml_tensor * ggml_rope(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   n_dims,
-        int                   mode,
-        int                   n_ctx) {
+        int                   mode) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
     );
 }
 
@@ -6317,10 +6311,9 @@ struct ggml_tensor * ggml_rope_inplace(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   n_dims,
-        int                   mode,
-        int                   n_ctx) {
+        int                   mode) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
     );
 }
 
@@ -6331,8 +6324,7 @@ struct ggml_tensor * ggml_rope_ext(
         struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
@@ -6340,8 +6332,8 @@ struct ggml_tensor * ggml_rope_ext(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
-        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, false
     );
 }
 
@@ -6352,8 +6344,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
         struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
@@ -6361,8 +6352,8 @@ struct ggml_tensor * ggml_rope_ext_inplace(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
-        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, true
     );
 }
 
@@ -6372,8 +6363,7 @@ struct ggml_tensor * ggml_rope_custom(
         struct ggml_tensor  * b,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
@@ -6381,8 +6371,8 @@ struct ggml_tensor * ggml_rope_custom(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
-        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, false
     );
 }
 
@@ -6392,8 +6382,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         struct ggml_tensor  * b,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
@@ -6401,21 +6390,11 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
-        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, true
     );
 }
 
-struct ggml_tensor * ggml_rope_xpos_inplace(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int                   n_dims,
-        float                 base,
-        bool                  down) {
-    return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
-}
-
 // ggml_rope_back
 
 struct ggml_tensor * ggml_rope_back(
@@ -6425,16 +6404,13 @@ struct ggml_tensor * ggml_rope_back(
         struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
-        int                   n_ctx,
-        int                   n_orig_ctx,
+        int                   n_ctx_orig,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
         float                 attn_factor,
         float                 beta_fast,
-        float                 beta_slow,
-        float                 xpos_base,
-        bool                  xpos_down) {
+        float                 beta_slow) {
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
@@ -6450,15 +6426,13 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+    int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
     memcpy(params +  5, &freq_base,    sizeof(float));
     memcpy(params +  6, &freq_scale,   sizeof(float));
     memcpy(params +  7, &ext_factor,   sizeof(float));
     memcpy(params +  8, &attn_factor,  sizeof(float));
     memcpy(params +  9, &beta_fast,    sizeof(float));
     memcpy(params + 10, &beta_slow,    sizeof(float));
-    memcpy(params + 11, &xpos_base,    sizeof(float));
-    memcpy(params + 12, &xpos_down,    sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE_BACK;
@@ -14227,8 +14201,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
+    float * cos_theta, float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -14245,18 +14218,19 @@ static void rope_yarn(
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
-    return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
 }
 
 static void ggml_rope_cache_init(
-     float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
-     float * cache, float sin_sign, float theta_scale
-) {
+     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+     float * cache, float sin_sign, float theta_scale) {
+    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
     float theta = theta_base;
     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
         rope_yarn(
-            theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
         );
         cache[i0 + 1] *= sin_sign;
 
@@ -14265,11 +14239,11 @@ static void ggml_rope_cache_init(
 }
 
 GGML_CALL void ggml_rope_yarn_corr_dims(
-    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
 ) {
     // start and end correction dims
-    float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base));
-    float end   =  ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base));
+    float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
+    float end   =  ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
     dims[0] = MAX(0, start);
     dims[1] = MIN(n_dims - 1, end);
 }
@@ -14289,15 +14263,11 @@ static void ggml_compute_forward_rope_f32(
 
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
-    // these two only relevant for xPos RoPE:
-    float xpos_base;
-    bool  xpos_down;
-
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
-    const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
@@ -14305,8 +14275,6 @@ static void ggml_compute_forward_rope_f32(
     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&xpos_base,   (int32_t *) dst->op_params + 11, sizeof(float));
-    memcpy(&xpos_down,   (int32_t *) dst->op_params + 12, sizeof(bool));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -14336,20 +14304,15 @@ static void ggml_compute_forward_rope_f32(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
 
     const float * freq_factors = NULL;
-    if (is_neox) {
-        if (src2 != NULL) {
-            GGML_ASSERT(src2->type == GGML_TYPE_F32);
-            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
     }
 
     // backward process uses inverse rotation by cos and sin.
@@ -14364,94 +14327,50 @@ static void ggml_compute_forward_rope_f32(
             const int64_t p = pos[i2];
 
             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
-                ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
+            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta_base = (float)p;
-
-                if (is_glm) {
-                    theta_base = MIN(p, n_ctx - 2);
-                    float block_theta = MAX(p - (n_ctx - 2), 0);
-                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                        const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base) * sin_sign;
-                        const float cos_block_theta = cosf(block_theta);
-                        const float sin_block_theta = sinf(block_theta) * sin_sign;
-
-                        theta_base  *= theta_scale;
-                        block_theta *= theta_scale;
-
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float x0 = src[0];
-                        const float x1 = src[n_dims/2];
-                        const float x2 = src[n_dims];
-                        const float x3 = src[n_dims/2*3];
-
-                        dst_data[0]          = x0*cos_theta - x1*sin_theta;
-                        dst_data[n_dims/2]   = x0*sin_theta + x1*cos_theta;
-                        dst_data[n_dims]     = x2*cos_block_theta - x3*sin_block_theta;
-                        dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
-                    }
-                } else if (!is_neox) {
-                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                         const float cos_theta = cache[i0 + 0];
                         const float sin_theta = cache[i0 + 1];
 
-                        // zeta scaling for xPos only:
-                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
-                        if (xpos_down) zeta = 1.0f / zeta;
-
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float x0 = src[0];
                         const float x1 = src[1];
 
-                        dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
-                        dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
+                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[1] = x0*sin_theta + x1*cos_theta;
                     }
                 } else {
-                    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
-                    for (int64_t ic = 0; ic < ne0; ic += 2) {
-                        if (ic < n_dims) {
-                            const int64_t i0 = ic/2;
-
-                            const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
-
-                            float cos_theta, sin_theta;
-                            rope_yarn(
-                                theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
-                                &cos_theta, &sin_theta
-                            );
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const int64_t ic = i0/2;
 
-                            sin_theta  *= sin_sign;
-                            theta_base *= theta_scale;
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
 
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
 
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims/2];
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];
 
-                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-                        } else {
-                            const int64_t i0 = ic;
+                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                    }
+                }
 
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                    const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                    float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-                            dst_data[0] = src[0];
-                            dst_data[1] = src[1];
-                        }
-                    }
+                    dst_data[0] = src[0];
+                    dst_data[1] = src[1];
                 }
             }
         }
@@ -14477,8 +14396,8 @@ static void ggml_compute_forward_rope_f16(
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
-    const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
     memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
@@ -14514,20 +14433,15 @@ static void ggml_compute_forward_rope_f16(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
 
     const float * freq_factors = NULL;
-    if (is_neox) {
-        if (src2 != NULL) {
-            GGML_ASSERT(src2->type == GGML_TYPE_F32);
-            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
     }
 
     // backward process uses inverse rotation by cos and sin.
@@ -14542,43 +14456,14 @@ static void ggml_compute_forward_rope_f16(
             const int64_t p = pos[i2];
 
             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
-                ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
+            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta_base = (float)p;
-
-                if (is_glm) {
-                    theta_base = MIN(p, n_ctx - 2);
-                    float block_theta = MAX(p - (n_ctx - 2), 0);
-                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                        const float cos_theta = cosf(theta_base);
-                        const float sin_theta = sinf(theta_base) * sin_sign;
-                        const float cos_block_theta = cosf(block_theta);
-                        const float sin_block_theta = sinf(block_theta) * sin_sign;
-
-                        theta_base  *= theta_scale;
-                        block_theta *= theta_scale;
-
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float x0 = GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
-                        const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
-                        const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
-
-                        dst_data[0]          = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims/2]   = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        dst_data[n_dims]     = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
-                        dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
-                    }
-                } else if (!is_neox) {
-                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                         const float cos_theta = cache[i0 + 0];
                         const float sin_theta = cache[i0 + 1];
 
@@ -14592,40 +14477,29 @@ static void ggml_compute_forward_rope_f16(
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
                 } else {
-                    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
-                    for (int64_t ic = 0; ic < ne0; ic += 2) {
-                        if (ic < n_dims) {
-                            const int64_t i0 = ic/2;
-
-                            const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const int64_t ic = i0/2;
 
-                            float cos_theta, sin_theta;
-                            rope_yarn(
-                                theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
-                                &cos_theta, &sin_theta
-                            );
-
-                            sin_theta  *= sin_sign;
-                            theta_base *= theta_scale;
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
 
-                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
 
-                            const float x0 = GGML_FP16_TO_FP32(src[0]);
-                            const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
 
-                            dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                            dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        } else {
-                            const int64_t i0 = ic;
+                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
+                }
 
-                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                    ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-                            dst_data[0] = src[0];
-                            dst_data[1] = src[1];
-                        }
-                    }
+                    dst_data[0] = src[0];
+                    dst_data[1] = src[1];
                 }
             }
         }
@@ -18327,9 +18201,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     //const int n_past = ((int32_t *) tensor->op_params)[0];
                     const int n_dims     = ((int32_t *) tensor->op_params)[1];
                     const int mode       = ((int32_t *) tensor->op_params)[2];
-                    const int n_ctx      = ((int32_t *) tensor->op_params)[3];
-                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
-                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+                    //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
+                    const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
+                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
                     memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
                     memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
@@ -18337,8 +18211,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
                     memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
                     memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
-                    memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
-                    memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));
 
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
@@ -18348,16 +18220,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src2,
                                 n_dims,
                                 mode,
-                                n_ctx,
-                                n_orig_ctx,
+                                n_ctx_orig,
                                 freq_base,
                                 freq_scale,
                                 ext_factor,
                                 attn_factor,
                                 beta_fast,
-                                beta_slow,
-                                xpos_base,
-                                xpos_down),
+                                beta_slow),
                             zero_table);
                 }
             } break;
@@ -18367,9 +18236,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     //const int n_past = ((int32_t *) tensor->op_params)[0];
                     const int n_dims     = ((int32_t *) tensor->op_params)[1];
                     const int mode       = ((int32_t *) tensor->op_params)[2];
-                    const int n_ctx      = ((int32_t *) tensor->op_params)[3];
-                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
-                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+                    //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
+                    const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
+                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
                     memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
                     memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
@@ -18377,8 +18246,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
                     memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
                     memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
-                    memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
-                    memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));
 
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
@@ -18388,16 +18255,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src2,
                                 n_dims,
                                 mode,
-                                n_ctx,
-                                n_orig_ctx,
+                                n_ctx_orig,
                                 freq_base,
                                 freq_scale,
                                 ext_factor,
                                 attn_factor,
                                 beta_fast,
                                 beta_slow,
-                                xpos_base,
-                                xpos_down,
                                 false),
                             zero_table);
                 }
index 8dc90a45dd8726b5ec672ce4ca167f2e4d4c5201..ce406a8af867a65875fce92bc5fa09d40a3f3da5 100644 (file)
@@ -1141,7 +1141,7 @@ struct test_rope : public test_case {
     const std::array<int64_t, 4> ne_a;
     int n_dims;
     int mode;
-    int n_ctx;
+    int n_ctx; // used to generate positions
     float fs; // freq_scale
     float ef; // ext_factor
     float af; // attn_factor
@@ -1168,7 +1168,7 @@ struct test_rope : public test_case {
         }
         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
         ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
-        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
         return out;
     }
 
@@ -1615,7 +1615,7 @@ struct llama_hparams {
 
     // cparams
     static constexpr uint32_t n_ctx = 512; // user-specified context size
-    static constexpr uint32_t n_orig_ctx = n_ctx;
+    static constexpr uint32_t n_ctx_orig = n_ctx;
 
     // batch
     int32_t n_tokens;
@@ -1806,13 +1806,13 @@ struct test_llama : public test_llm {
 
                 Qcur = ggml_rope_ext(
                     ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens), inp_pos, nullptr,
-                    hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
                 Kcur = ggml_rope_ext(
                     ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
-                    hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
@@ -1931,12 +1931,12 @@ struct test_falcon : public test_llm {
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_ext(
-                    ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
+                    ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
                 Kcur = ggml_rope_ext(
-                    ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
+                    ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
@@ -2236,15 +2236,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                 for (float ef : { 0.0f, 0.7465f }) {
                     for (float af : { 1.0f, 1.4245f }) {
                         for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                            // TODO: ff not supported yet for !neox
-                            test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B
-                            if (all) {
-                                test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B
-                                test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B
-                                test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B
-                            }
-
                             for (bool ff : {false, true}) { // freq_factors
+                                test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
+
+                                if (all) {
+                                    test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
+                                    test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
+                                    test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
+                                }
+
                                 if (all) {
                                     test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
                                     test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
@@ -2256,6 +2256,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                                 test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
                             }
                         }
+
                         all = false;
                     }
                 }
index 21ca43be3a963bdc5dba3fdba4074d5597f3d96a..a353276459b2d144abd609b318d7e60558fbf4c7 100644 (file)
@@ -1465,7 +1465,7 @@ int main(int argc, const char ** argv) {
                             continue;
                         }
 
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
 
                         GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
                         check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
@@ -1505,7 +1505,7 @@ int main(int argc, const char ** argv) {
                             continue;
                         }
 
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
 
                         GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
                         check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);