]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
fix op rope, add rope_back (llama/20293)
authorNeo Zhang <redacted>
Wed, 11 Mar 2026 01:53:34 +0000 (09:53 +0800)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
ggml/src/ggml-sycl/convert.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/rope.cpp
ggml/src/ggml-sycl/rope.hpp

index f93bd0df7d7e40240814f61456a256df2288524c..6e621f2154d2bf2d3c8612d612a525d66030bc6b 100644 (file)
@@ -39,6 +39,11 @@ template<typename dst_t, typename src_t>
         return sycl::ext::oneapi::bfloat16(float(x));
     } else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
         return static_cast<float>(x);
+    } else if constexpr (std::is_same_v<src_t, sycl::float2> && std::is_same_v<dst_t, sycl::half2>) {
+        return x.template convert<sycl::half, sycl::rounding_mode::rte>();
+    } else if constexpr (std::is_same_v<src_t, sycl::float2> &&
+                         std::is_same_v<dst_t, sycl::vec<sycl::ext::oneapi::bfloat16, 2>>) {
+        return {x.x, x.y};
     } else if constexpr(std::is_same_v<dst_t, int32_t>) {
         return int32_t(x);
     } else {
@@ -46,4 +51,5 @@ template<typename dst_t, typename src_t>
     }
 }
 
+
 #endif  // GGML_SYCL_CONVERT_HPP
index 66dfc4532c0017e8d20a7bf0ec80b8df41315a9f..f887061b279b5c0f6aecf092cb8681e5617e221a 100644 (file)
@@ -4145,6 +4145,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_ROPE:
             ggml_sycl_rope(ctx, dst);
             break;
+        case GGML_OP_ROPE_BACK:
+            ggml_sycl_rope_back(ctx, dst);
+            break;
         case GGML_OP_IM2COL:
             ggml_sycl_im2col(ctx, dst);
             break;
@@ -4851,6 +4854,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return max_bias == 0.0f;
         }
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
         case GGML_OP_IM2COL:
             return true;
         case GGML_OP_UPSCALE:
index aeaa58b95b3281c5906182385321814da6e4ebe1..9d83a1e9fa09b640c292d8ee5f4ab9359a77d981 100644 (file)
@@ -1,4 +1,5 @@
 #include "rope.hpp"
+#include "convert.hpp"
 #include "ggml-sycl/common.hpp"
 #include "ggml.h"
 
@@ -15,366 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
     return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
 }
 
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
-    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta) {
-    // Get n-d rotational scaling corrected for extrapolation
+template <bool forward>
+static void rope_yarn(const float theta_extrap, const float freq_scale,
+                      const rope_corr_dims corr_dims, const int64_t i0,
+                      const float ext_factor, float mscale, float &cos_theta,
+                      float &sin_theta) {
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
     if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+        float ramp_mix =
+            rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 
-        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
     }
-    *cos_theta = sycl::cos(theta) * mscale;
-    *sin_theta = sycl::sin(theta) * mscale;
+    cos_theta = sycl::cos(theta) * mscale;
+    sin_theta = sycl::sin(theta) * mscale;
+    if (!forward) {
+        sin_theta *= -1.0f;
+    }
 }
 
-template <typename T, bool has_ff>
-static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-                      const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
-                      const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
-                      const sycl::nd_item<3> & item_ct1) {
-    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
-
-    if (i0 >= ne0) {
+template <bool forward, bool has_ff, typename T, typename D>
+static void rope_norm(const T *x, D *dst, const int ne00, const int ne01,
+                      const int ne02, const int s01, const int s02,
+                      const int s03, const int s1, const int s2, const int s3,
+                      const int n_dims, const int32_t *pos,
+                      const float freq_scale, const float ext_factor,
+                      const float attn_factor, const rope_corr_dims corr_dims,
+                      const float theta_scale, const float *freq_factors,
+                      const int64_t *row_indices, const int set_rows_stride) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                        item_ct1.get_local_id(1));
+
+    if (i0 >= ne00) {
         return;
     }
 
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
+    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                        item_ct1.get_local_id(2);
 
-    const int row0     = row % ne1;
-    const int channel0 = row / ne1;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
-    const int i  = row * ne0 + i0;
-    const int i2 = channel0 * s2 + row0 * s1 + i0;
+    int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
+    const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
+
+    if (set_rows_stride != 0) {
+        idst = i1 * s1 + i0;
+        idst += row_indices[i2] * set_rows_stride;
+    }
 
+    const auto &store_coaelsced = [&](float x0, float x1) {
+        if constexpr (std::is_same_v<float, D>) {
+            sycl::float2 v = sycl::float2(x0, x1);
+            ggml_sycl_memcpy_1<8>(dst + idst, &v);
+        } else if constexpr (std::is_same_v<sycl::half, D>) {
+            sycl::half2 v = sycl::half2(x0, x1);
+            ggml_sycl_memcpy_1<4>(dst + idst, &v);
+        }
+    };
     if (i0 >= n_dims) {
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
+        store_coaelsced(x[ix + 0], x[ix + 1]);
         return;
     }
 
-    const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
+    const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
+                       ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i2 + 0];
-    const float x1 = x[i2 + 1];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + 1];
 
-    dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
-    dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
+    store_coaelsced(x0 * cos_theta - x1 * sin_theta,
+                    x0 * sin_theta + x1 * cos_theta);
 }
 
-template <typename T, bool has_ff>
-static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-                      const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-                      const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
-                      const sycl::nd_item<3> & item_ct1) {
-    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
-
-    if (i0 >= ne0) {
+template <bool forward, bool has_ff, typename T, typename D>
+static void rope_neox(const T *x, D *dst, const int ne00, const int ne01,
+                      const int ne02, const int s01, const int s02,
+                      const int s03, const int s1, const int s2, const int s3,
+                      const int n_dims, const int32_t *pos,
+                      const float freq_scale, const float ext_factor,
+                      const float attn_factor, const rope_corr_dims corr_dims,
+                      const float theta_scale, const float *freq_factors,
+                      const int64_t *row_indices, const int set_rows_stride) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                        item_ct1.get_local_id(1));
+
+    if (i0 >= ne00) {
         return;
     }
 
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
+    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                        item_ct1.get_local_id(2);
 
-    const int row0     = row % ne1;
-    const int channel0 = row / ne1;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
-    const int i  = row * ne0 + i0 / 2;
-    const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
+    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
+    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
+
+    if (set_rows_stride != 0) {
+        idst = i1 * s1 + i0 / 2;
+        idst += row_indices[i2] * set_rows_stride;
+    }
 
     if (i0 >= n_dims) {
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
+        dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);
+        dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);
+
         return;
     }
 
-    const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
+    const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
+                       ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i2 + 0];
-    const float x1 = x[i2 + n_dims / 2];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims / 2];
 
-    dst[i + 0]          = x0 * cos_theta - x1 * sin_theta;
-    dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
+    dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);
+    dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);
 }
 
-template <typename T, bool has_ff>
-static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
-                        const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
-                        const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
-                        const float theta_scale, const float * freq_factors, const mrope_sections sections,
-                        const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
-    // get index pos
-    const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
-    if (i0 >= ne0) {
+template <bool forward, bool has_ff, typename T>
+static void rope_multi(const T *x, T *dst, const int ne00, const int ne01,
+                       const int ne02, const int s01, const int s02,
+                       const int s03, const int s1, const int s2, const int s3,
+                       const int n_dims, const int32_t *pos,
+                       const float freq_scale, const float ext_factor,
+                       const float attn_factor, const rope_corr_dims corr_dims,
+                       const float theta_scale, const float *freq_factors,
+                       const mrope_sections sections, const bool is_imrope) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                        item_ct1.get_local_id(1));
+
+    if (i0 >= ne00) {
         return;
     }
-    const int    row_dst   = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
 
-    const int    row_x     = row_dst % ne1;
-    const int    channel_x = row_dst / ne1;
-    const int    idst      = (row_dst * ne0) + (i0 / 2);
-    const size_t ix        = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
+    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                        item_ct1.get_local_id(2);
+
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
+
+    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
+    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
 
     if (i0 >= n_dims) {
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
+        dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];
+        dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];
+
         return;
     }
 
-    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+    const int sect_dims =
+        sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
     const int sec_w = sections.v[1] + sections.v[0];
     const int sector = (i0 / 2) % sect_dims;
 
-
     float theta_base = 0.0;
     if (is_imrope) {
-        if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
-            theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
-        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
-            theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
-        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
-            theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
+        if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
+            theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
+        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
+            theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
+            theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
         } else {
-            theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+            theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
         }
     } else {
         if (sector < sections.v[0]) {
-            theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sections.v[0] && sector < sec_w) {
-            theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
-            theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sec_w + sections.v[2]) {
-            theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+            theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
+        } else if (sector >= sections.v[0] && sector < sec_w) {
+            theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
+        } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+            theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
+        } else if (sector >= sec_w + sections.v[2]) {
+            theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
         }
     }
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
-    float       cos_theta;
-    float       sin_theta;
-    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-    const float x0 = x[ix + 0];
-    const float x1 = x[ix + n_dims/2];
 
-    // store results in dst
-    dst[idst + 0]      = x0 * cos_theta - x1 * sin_theta;
-    dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
-}
+    float cos_theta;
+    float sin_theta;
 
+    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
+                       ext_factor, attn_factor, cos_theta, sin_theta);
 
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims / 2];
+
+    dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
+    dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
+}
 
-template <typename T, bool has_ff>
-static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
-                        const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
-                        const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
-                        const float theta_scale, const float * freq_factors, const mrope_sections sections,
-                        const sycl::nd_item<3> & item_ct1) {
-    // get index pos
-    const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
-    if (i0 >= ne0) {
+template <bool forward, bool has_ff, typename T>
+static void rope_vision(const T *x, T *dst, const int ne00, const int ne01,
+                        const int ne02, const int s01, const int s02,
+                        const int s03, const int s1, const int s2, const int s3,
+                        const int n_dims, const int32_t *pos,
+                        const float freq_scale, const float ext_factor,
+                        const float attn_factor, const rope_corr_dims corr_dims,
+                        const float theta_scale, const float *freq_factors,
+                        const mrope_sections sections) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                        item_ct1.get_local_id(1));
+
+    if (i0 >= ne00) {
         return;
     }
-    const int    row_dst   = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
-    const int    row_x     = row_dst % ne1;
-    const int    channel_x = row_dst / ne1;
-    const int    idst      = (row_dst * ne0) + (i0 / 2);
-    const size_t ix        = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
+
+    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                        item_ct1.get_local_id(2);
+
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
+
+    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
+    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
 
     const int sect_dims = sections.v[0] + sections.v[1];
-    const int sector    = (i0 / 2) % sect_dims;
+    const int sec_w = sections.v[1] + sections.v[0];
+    const int sector = (i0 / 2) % sect_dims;
 
-    float theta_base = 0.0f;
+    float theta_base = 0.0;
     if (sector < sections.v[0]) {
         const int p = sector;
-        theta_base  = pos[channel_x] * sycl::pow(theta_scale, (float) p);
-    } else {
+        theta_base = pos[i2] * dpct::pow(theta_scale, p);
+    } else if (sector >= sections.v[0] && sector < sec_w) {
         const int p = sector - sections.v[0];
-        theta_base  = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
+        theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);
     }
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
-    float       cos_theta;
-    float       sin_theta;
-    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
+                       ext_factor, attn_factor, cos_theta, sin_theta);
+
     const float x0 = x[ix + 0];
     const float x1 = x[ix + n_dims];
 
-    // store results in dst
-    dst[idst + 0]      = x0 * cos_theta - x1 * sin_theta;
+    dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
     dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
 }
 
-template <typename T>
-static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
-                           const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
-                           const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
-                           const float * freq_factors, queue_ptr stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
-    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int            num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
-    const sycl::range<3> block_nums(1, num_blocks_x, nr);
+template <bool forward, typename T, typename D>
+static void
+rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,
+               const int ne02, const int s01, const int s02, const int s03,
+               const int s1, const int s2, const int s3, const int n_dims,
+               const int nr, const int32_t *pos, const float freq_scale,
+               const float freq_base, const float ext_factor,
+               const float attn_factor, const rope_corr_dims corr_dims,
+               const float *freq_factors, const int64_t *row_indices,
+               const int set_rows_stride, dpct::queue_ptr stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
+    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x =
+        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const dpct::dim3 block_nums(nr, n_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-
     if (freq_factors == nullptr) {
-        /*
-        DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                                theta_scale, freq_factors, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_norm<forward, false>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, row_indices, set_rows_stride);
+            });
     } else {
-        /*
-        DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                               theta_scale, freq_factors, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_norm<forward, true>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, row_indices, set_rows_stride);
+            });
     }
 }
 
-template <typename T>
-static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
-                           const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
-                           const float freq_base, const float ext_factor, const float attn_factor,
-                           const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
-    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int            num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
-    const sycl::range<3> block_nums(1, num_blocks_x, nr);
+template <bool forward, typename T, typename D>
+static void
+rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,
+               const int ne02, const int s01, const int s02, const int s03,
+               const int s1, const int s2, const int s3, const int n_dims,
+               const int nr, const int32_t *pos, const float freq_scale,
+               const float freq_base, const float ext_factor,
+               const float attn_factor, const rope_corr_dims corr_dims,
+               const float *freq_factors, const int64_t *row_indices,
+               const int set_rows_stride, dpct::queue_ptr stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
+    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x =
+        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const dpct::dim3 block_nums(nr, n_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-
     if (freq_factors == nullptr) {
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                                theta_scale, freq_factors, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_neox<forward, false>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, row_indices, set_rows_stride);
+            });
     } else {
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                               theta_scale, freq_factors, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_neox<forward, true>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, row_indices, set_rows_stride);
+            });
     }
 }
 
-template <typename T>
-static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
-                             const size_t s2, const int n_dims, const int nr, const int32_t * pos,
-                             const float freq_scale, const float freq_base, const float ext_factor,
-                             const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
-                             const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
-    const sycl::range<3>    block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int               n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
-    const sycl::range<3>    grid_dims(1, n_blocks_y, nr);
-    const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
-
-    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
-    // Add FP16 capability check if T could be sycl::half
-    if constexpr (std::is_same_v<T, sycl::half>) {
-        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-    }
-    // launch kernel
+template <bool forward, typename T>
+static void
+rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,
+                const int ne02, const int s01, const int s02, const int s03,
+                const int s1, const int s2, const int s3, const int n_dims,
+                const int nr, const int32_t *pos, const float freq_scale,
+                const float freq_base, const float ext_factor,
+                const float attn_factor, const rope_corr_dims corr_dims,
+                const float *freq_factors, const mrope_sections sections,
+                const bool is_imrope, dpct::queue_ptr stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
+    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x =
+        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const dpct::dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
+
     if (freq_factors == nullptr) {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
-            rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                  corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_multi<forward, false, T>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, sections, is_imrope);
+            });
     } else {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
-            rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                 corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_multi<forward, true, T>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, sections, is_imrope);
+            });
     }
 }
 
+template <bool forward, typename T>
+static void
+rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,
+                 const int ne02, const int s01, const int s02, const int s03,
+                 const int s1, const int s2, const int s3, const int n_dims,
+                 const int nr, const int32_t *pos, const float freq_scale,
+                 const float freq_base, const float ext_factor,
+                 const float attn_factor, const rope_corr_dims corr_dims,
+                 const float *freq_factors, const mrope_sections sections,
+                 dpct::queue_ptr stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
+    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x =
+        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const dpct::dim3 block_nums(nr, n_blocks_x, 1);
 
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-
-// rope vision
-template <typename T>
-static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
-                             const size_t s2, const int n_dims, const int nr, const int32_t * pos,
-                             const float freq_scale, const float freq_base, const float ext_factor,
-                             const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
-                             const mrope_sections sections, queue_ptr stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
-    const sycl::range<3>    block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int               n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
-    const sycl::range<3>    grid_dims(1, n_blocks_y, nr);
-    const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
-
-    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
-    // Add FP16 capability check if T could be sycl::half
-    if constexpr (std::is_same_v<T, sycl::half>) {
-        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-    }
-    // launch kernel
     if (freq_factors == nullptr) {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
-            rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_vision<forward, false, T>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, sections);
+            });
     } else {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
-            rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                 corr_dims, theta_scale, freq_factors, sections, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_vision<forward, true, T>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, sections);
+            });
     }
 }
 
-inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+template <bool forward>
+void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,
+                            const ggml_tensor *set_rows = nullptr) {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
+    const ggml_tensor *src2 = dst->src[2];
+
+    const float *src0_d = (const float *)src0->data;
+    const float *src1_d = (const float *)src1->data;
+
+    void *dst_d = dst->data;
+    const int64_t *row_indices = nullptr;
+    ggml_type dst_type = dst->type;
+    int set_rows_stride = 0;
+
+    if (set_rows != nullptr) {
+        GGML_ASSERT(forward);
+        dst_d = set_rows->data;
+        row_indices = (const int64_t *)set_rows->src[1]->data;
+        dst_type = set_rows->type;
+        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
+    }
+    dpct::queue_ptr stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type ||
+                (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
 
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    const int64_t ne00 = dst->src[0]->ne[0]; // head dims
-    const int64_t ne01 = dst->src[0]->ne[1]; // num heads
-    const int64_t ne02 = dst->src[0]->ne[2]; // num heads
-    const int64_t nr = ggml_nrows(dst->src[0]);
+    const int64_t ne00 = src0->ne[0]; // head dims
+    const int64_t ne01 = src0->ne[1]; // num heads
+    const int64_t ne02 = src0->ne[2]; // num heads
+    const int64_t nr = ggml_nrows(src0);
 
-    const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
-    const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
+    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
+    const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
+    const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
 
+    const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
+    const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
+    const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
 
-    //const int n_past      = ((int32_t *) dst->op_params)[0];
-    const int n_dims      = ((int32_t *) dst->op_params)[1];
-    const int mode        = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx       = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig  = ((int32_t *) dst->op_params)[4];
+    const int n_dims = ((int32_t *)dst->op_params)[1];
+    const int mode = ((int32_t *)dst->op_params)[2];
+    const int n_ctx_orig = ((int32_t *)dst->op_params)[4];
     mrope_sections sections;
 
-    // RoPE alteration for extended context
     float freq_base;
     float freq_scale;
     float ext_factor;
@@ -382,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     float beta_fast;
     float beta_slow;
 
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);
+    memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));
+    memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
+    memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));
+    memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));
+    memcpy(&sections.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
@@ -396,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
-        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
+        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
+                    sections.v[2] > 0);
     }
 
     if (is_vision) {
-        GGML_ASSERT(n_dims == ne00/2);
+        GGML_ASSERT(n_dims == ne00 / 2);
     }
 
-    const int32_t * pos = (const int32_t *) dst->src[1]->data;
+    const int32_t *pos = (const int32_t *)src1_d;
 
-    const float * freq_factors = nullptr;
-    if (dst->src[2] != nullptr) {
-        freq_factors = (const float *) dst->src[2]->data;
+    const float *freq_factors = nullptr;
+    if (src2 != nullptr) {
+        freq_factors = (const float *)src2->data;
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
-
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
+                             beta_slow, corr_dims.v);
 
     // compute
     if (is_neox) {
         GGML_SYCL_DEBUG("%s: neox path\n", __func__);
-        if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
-                           pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
-        } else if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
-                           n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
-                           main_stream);
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
+            rope_neox_sycl<forward, float, float>(
+                (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
+                s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
+            rope_neox_sycl<forward, float, sycl::half>(
+                (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
+                s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
+            rope_neox_sycl<forward, sycl::half, sycl::half>(
+                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
+                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                row_indices, set_rows_stride, stream);
         } else {
-            GGML_ABORT("fatal error");
+            GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
     } else if (is_mrope && !is_vision) {
         GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
-        if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
-                s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                freq_factors, sections, is_imrope, main_stream);
-        } else if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
-                             nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
-                             is_imrope, main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,
+                                     ne00, ne01, ne02, s01, s02, s03, s1, s2,
+                                     s3, n_dims, nr, pos, freq_scale, freq_base,
+                                     ext_factor, attn_factor, corr_dims,
+                                     freq_factors, sections, is_imrope, stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_multi_sycl<forward>(
+                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
+                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                sections, is_imrope, stream);
         } else {
             GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
     } else if (is_vision) {
         GGML_SYCL_DEBUG("%s: vision path\n", __func__);
-        if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
-                             s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                             freq_factors, sections, main_stream);
-        } else if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
-                             nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
-                             main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_vision_sycl<forward>(
+                (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
+                s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                ext_factor, attn_factor, corr_dims, freq_factors, sections,
+                stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_vision_sycl<forward>(
+                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
+                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                sections, stream);
         } else {
             GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
     } else {
         GGML_SYCL_DEBUG("%s: norm path\n", __func__);
-        if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
-                           pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
-        } else if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
-                           n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
-                           main_stream);
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
+            rope_norm_sycl<forward, float, float>(
+                (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
+                s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
+            rope_norm_sycl<forward, float, sycl::half>(
+                (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
+                s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
+            rope_norm_sycl<forward, sycl::half, sycl::half>(
+                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
+                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                row_indices, set_rows_stride, stream);
         } else {
-            GGML_ABORT("fatal error");
+            GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
     }
 }
 
-void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
-    ggml_sycl_op_rope(ctx, dst);
+
+    ggml_sycl_op_rope_impl<true>(ctx, dst);
 }
 
+void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
+    ggml_sycl_op_rope_impl<false>(ctx, dst);
+}
+
+void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,
+                          ggml_tensor *set_rows) {
+    scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);
+    ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);
+}
index 8c7141aac5c9b1fed0d9ff573ec78f614b7a400e..b95a585808be42c20fe5221f1db5d32c663feee8 100644 (file)
 
 #include "common.hpp"
 
+#define SYCL_ROPE_BLOCK_SIZE 256
+
 void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
 
+void ggml_sycl_rope_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_rope_fused(ggml_backend_sycl_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
+
 #endif // GGML_SYCL_ROPE_HPP