]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Update SYCL-Rope op and Refactor (llama/8157)
authorzhentaoyu <redacted>
Mon, 1 Jul 2024 11:39:06 +0000 (19:39 +0800)
committerGeorgi Gerganov <redacted>
Mon, 8 Jul 2024 10:03:28 +0000 (13:03 +0300)
* align with rope.cu and move sycl-op to a single file

src/ggml-sycl.cpp
src/ggml-sycl/backend.hpp
src/ggml-sycl/rope.cpp [new file with mode: 0644]
src/ggml-sycl/rope.hpp [new file with mode: 0644]

index 4a668a2c34d3ead78dc4011c0347d5eb59f1404e..30d8a5b33b61335cdbf6fa6e02a3a7191708a378 100644 (file)
@@ -978,114 +978,6 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
-static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
-    return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
-}
-
-struct rope_corr_dims {
-    float v[4];
-};
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
-    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
-    }
-    *cos_theta = sycl::cos(theta) * mscale;
-    *sin_theta = sycl::sin(theta) * mscale;
-}
-
-// rope == RoPE == rotary positional embedding
-template<typename T, bool has_pos>
-static void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
-,
-    const sycl::nd_item<3> &item_ct1) {
-    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p * dpct::pow(freq_base, -float(col) / ncols);
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + 1];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + 1] = x0*sin_theta + x1*cos_theta;
-}
-
-template<typename T, bool has_pos, bool has_freq_facs>
-static void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
-    const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
-    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;
-
-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
-    const int i  = row*ncols + ib*n_dims + ic/2;
-    const int i2 = row/p_delta_rows;
-
-    float cur_rot = inv_ndims * ic - ib;
-
-    const int p = has_pos ? pos[i2] : 0;
-    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
-
-    const float theta_base =
-        p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
-
-    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -2241,110 +2133,6 @@ static void clamp_f32_sycl(const float *x, float *dst, const float min,
         });
 }
 
-template <typename T>
-static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
-                      const int32_t *pos, float freq_scale, int p_delta_rows,
-                      float freq_base, float ext_factor, float attn_factor,
-                      rope_corr_dims corr_dims, queue_ptr stream) {
-    GGML_ASSERT(ncols % 2 == 0);
-    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
-    const sycl::range<3> block_nums(1, num_blocks_x, nrows);
-    if (pos == nullptr) {
-        /*
-        DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope<T, false>(x, dst, ncols, pos, freq_scale, p_delta_rows,
-                               freq_base, ext_factor, attn_factor, corr_dims,
-                               item_ct1);
-            });
-    } else {
-        /*
-        DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope<T, true>(x, dst, ncols, pos, freq_scale, p_delta_rows,
-                              freq_base, ext_factor, attn_factor, corr_dims,
-                              item_ct1);
-            });
-    }
-}
-
-template <typename T>
-static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
-                           const int32_t *pos, float freq_scale,
-                           int p_delta_rows, float freq_base, float ext_factor,
-                           float attn_factor, rope_corr_dims corr_dims,
-                           const float * freq_factors, queue_ptr stream) {
-    GGML_ASSERT(ncols % 2 == 0);
-    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
-    const sycl::range<3> block_nums(1, num_blocks_x, nrows);
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.0f / n_dims;
-
-    if (pos == nullptr) {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-        if (freq_factors == nullptr) {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
-                                        p_delta_rows, ext_factor, attn_factor,
-                                        corr_dims, theta_scale, inv_ndims, freq_factors,
-                                        item_ct1);
-                });
-        } else {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
-                                        p_delta_rows, ext_factor, attn_factor,
-                                        corr_dims, theta_scale, inv_ndims, freq_factors,
-                                        item_ct1);
-                });
-        }
-    } else {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        if (freq_factors == nullptr) {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
-                                       p_delta_rows, ext_factor, attn_factor,
-                                       corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
-                });
-        } else {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
-                                       p_delta_rows, ext_factor, attn_factor,
-                                       corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
-                });
-        }
-    }
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -3461,97 +3249,6 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-    const ggml_tensor * src2 = dst->src[2];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
-    GGML_ASSERT(src0->type == dst->type);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t nrows = ggml_nrows(src0);
-
-    //const int n_past      = ((int32_t *) dst->op_params)[0];
-    const int n_dims      = ((int32_t *) dst->op_params)[1];
-    const int mode        = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx       = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig  = ((int32_t *) dst->op_params)[4];
-
-    // RoPE alteration for extended context
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-
-    const float * freq_factors = nullptr;
-    const int32_t * pos = nullptr;
-    if ((mode & 1) == 0) {
-        GGML_ASSERT(src1->type == GGML_TYPE_I32);
-        GGML_ASSERT(src1->ne[0] == ne2);
-        pos = (const int32_t *) src1_dd;
-    }
-
-    const bool is_neox = mode & 2;
-
-#pragma message("TODO: update rope NORM mode to match NEOX mode")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
-
-    if (is_neox) {
-        pos = (const int32_t *) src1_dd;
-
-        if (src2 != nullptr) {
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
-    }
-
-    rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
-
-    // compute
-    if (is_neox) {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_neox_sycl(
-                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, main_stream
-            );
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
-                           ne00, n_dims, nrows, pos, freq_scale, ne01,
-                           freq_base, ext_factor, attn_factor, corr_dims,
-                           freq_factors, main_stream);
-        } else {
-            GGML_ASSERT(false);
-        }
-    } else {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_sycl(
-                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, main_stream
-            );
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
-                      nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                      attn_factor, corr_dims, main_stream);
-        } else {
-            GGML_ASSERT(false);
-        }
-    }
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
 static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                 const ggml_tensor *src1, ggml_tensor *dst,
                                 const float *src0_dd, const float *src1_dd,
@@ -6241,7 +5938,9 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+            return true;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
index 2d37e271f90508825d6d7fdae56df4927b0b00b2..d5a63cd710cc34ecad121ecadf38e88f34ad9eb2 100644 (file)
@@ -19,5 +19,6 @@
 #include "dmmv.hpp"
 #include "mmq.hpp"
 #include "mmvq.hpp"
+#include "rope.hpp"
 
 #endif // GGML_SYCL_BACKEND_HPP
diff --git a/src/ggml-sycl/rope.cpp b/src/ggml-sycl/rope.cpp
new file mode 100644 (file)
index 0000000..eabf169
--- /dev/null
@@ -0,0 +1,275 @@
+#include "rope.hpp"
+
+struct rope_corr_dims {
+    float v[2];
+};
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
+    return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
+    }
+    *cos_theta = sycl::cos(theta) * mscale;
+    *sin_theta = sycl::sin(theta) * mscale;
+}
+
+template<typename T, bool has_ff>
+static void rope_norm(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
+    const sycl::nd_item<3> &item_ct1) {
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                         item_ct1.get_local_id(1));
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2);
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i = row*ne0 + i0;
+    const int i2 = row/p_delta_rows;
+
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + 1];
+
+    dst[i + 0] = x0*cos_theta - x1*sin_theta;
+    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
+static void rope_neox(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
+    const sycl::nd_item<3> &item_ct1) {
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                         item_ct1.get_local_id(1));
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2);
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ne0 + i0/2;
+    const int i2 = row/p_delta_rows;
+
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + n_dims/2];
+
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
+template <typename T>
+static void rope_norm_sycl(
+    const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+    const sycl::range<3> block_nums(1, num_blocks_x, nr);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+    if (freq_factors == nullptr) {
+        /*
+        DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
+                               ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
+                               item_ct1);
+            });
+    } else {
+        /*
+        DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
+                              ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
+                              item_ct1);
+            });
+    }
+}
+
+template <typename T>
+static void rope_neox_sycl(
+    const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+    const sycl::range<3> block_nums(1, num_blocks_x, nr);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    dpct::has_capability_or_fail(stream->get_device(),
+                                    {sycl::aspect::fp16});
+
+    if (freq_factors == nullptr) {
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
+                                    p_delta_rows, ext_factor, attn_factor,
+                                    corr_dims, theta_scale, freq_factors,
+                                    item_ct1);
+            });
+    } else {
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
+                                    p_delta_rows, ext_factor, attn_factor,
+                                    corr_dims, theta_scale, freq_factors,
+                                    item_ct1);
+            });
+    }
+}
+
+void ggml_sycl_op_rope(
+    ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+    const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) {
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t nr = ggml_nrows(src0);
+
+    //const int n_past      = ((int32_t *) dst->op_params)[0];
+    const int n_dims      = ((int32_t *) dst->op_params)[1];
+    const int mode        = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx       = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig  = ((int32_t *) dst->op_params)[4];
+
+    // RoPE alteration for extended context
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+
+    const bool is_neox = mode & 2;
+
+    const int32_t * pos = (const int32_t *) src1_dd;
+
+    const float * freq_factors = nullptr;
+    if (src2 != nullptr) {
+        freq_factors = (const float *) src2->data;
+    }
+
+    rope_corr_dims corr_dims;
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
+
+    // compute
+    if (is_neox) {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_neox_sycl(
+                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, main_stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_neox_sycl(
+                (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, main_stream
+            );
+        } else {
+            GGML_ASSERT(false);
+        }
+    } else {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_norm_sycl(
+                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, main_stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_norm_sycl(
+                (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, main_stream
+            );
+        } else {
+            GGML_ASSERT(false);
+        }
+    }
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
diff --git a/src/ggml-sycl/rope.hpp b/src/ggml-sycl/rope.hpp
new file mode 100644 (file)
index 0000000..00354c3
--- /dev/null
@@ -0,0 +1,22 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_ROPE_HPP
+#define GGML_SYCL_ROPE_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_rope(
+    ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+    const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream);
+
+#endif // GGML_SYCL_ROPE_HPP