]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
SYCL: Add non-contiguous support in ROPE (#12993)
authorAkarshan Biswas <redacted>
Mon, 21 Apr 2025 13:43:30 +0000 (19:13 +0530)
committerGitHub <redacted>
Mon, 21 Apr 2025 13:43:30 +0000 (19:13 +0530)
ggml-ci

ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/rope.cpp
ggml/src/ggml-sycl/rope.hpp

index 1de34c96298b9ee57236a82d8f264e2ec71b97f2..8081a77b74f673fb92a9d51ab9282a84134228fe 100644 (file)
@@ -3168,11 +3168,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_op_diag_mask_inf(ctx, dst);
 }
 
-static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
-    ggml_sycl_op_rope(ctx, dst);
-}
-
 static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_pool2d(ctx, dst);
 }
@@ -4002,7 +3997,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 if (mode == GGML_ROPE_TYPE_MROPE) {
                     return false;
                 }
-                return ggml_is_contiguous(op->src[0]);
+                return true;
             }
         case GGML_OP_IM2COL:
             return true;
index 80e050f2414964236122c9faab4a3680c44c3d3d..4e276d3b62e42b20b048334a4e864714ad97e13b 100644 (file)
@@ -34,23 +34,21 @@ static void rope_yarn(
     *sin_theta = sycl::sin(theta) * mscale;
 }
 
-template<typename T, bool has_ff>
-static void rope_norm(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
-    const sycl::nd_item<3> &item_ct1) {
-    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
+template <typename T, bool has_ff>
+static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+                      const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
+                      const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
+                      const sycl::nd_item<3> & item_ct1) {
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
+    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row * ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -58,42 +56,43 @@ static void rope_norm(
         return;
     }
 
-    const int i = row*ne0 + i0;
-    const int i2 = row/p_delta_rows;
+    const int row0     = row % ne1;
+    const int channel0 = row / ne1;
+
+    const int i  = row * ne0 + i0;
+    const int i2 = channel0 * s2 + row0 * s1 + i0;
 
-    const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
+    const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
 
-    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + 1];
+    const float x0 = x[i2 + 0];
+    const float x1 = x[i2 + 1];
 
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+    dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
+    dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
 }
 
-template<typename T, bool has_ff>
-static void rope_neox(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
-    const sycl::nd_item<3> &item_ct1) {
-    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
+template <typename T, bool has_ff>
+static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+                      const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+                      const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
+                      const sycl::nd_item<3> & item_ct1) {
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
+    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row * ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -101,23 +100,26 @@ static void rope_neox(
         return;
     }
 
-    const int i  = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows;
+    const int row0     = row % ne1;
+    const int channel0 = row / ne1;
+
+    const int i  = row * ne0 + i0 / 2;
+    const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
 
-    const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
+    const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
 
-    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
+    const float x0 = x[i2 + 0];
+    const float x1 = x[i2 + n_dims / 2];
 
-    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[i + 0]          = x0 * cos_theta - x1 * sin_theta;
+    dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
 }
 
 template <typename T, bool has_ff>
@@ -163,18 +165,18 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons
 }
 
 template <typename T>
-static void rope_norm_sycl(
-    const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
+                           const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
+                           const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+                           const float * freq_factors, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+    const int            num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
     const sycl::range<3> block_nums(1, num_blocks_x, nr);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-    dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
     if (freq_factors == nullptr) {
         /*
@@ -182,61 +184,47 @@ static void rope_norm_sycl(
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
-                               ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
-                               item_ct1);
-            });
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                                theta_scale, freq_factors, item_ct1);
+        });
     } else {
         /*
         DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
-                              ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
-                              item_ct1);
-            });
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                               theta_scale, freq_factors, item_ct1);
+        });
     }
 }
 
 template <typename T>
-static void rope_neox_sycl(
-    const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
+                           const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
+                           const float freq_base, const float ext_factor, const float attn_factor,
+                           const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+    const int            num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
     const sycl::range<3> block_nums(1, num_blocks_x, nr);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-    dpct::has_capability_or_fail(stream->get_device(),
-                                    {sycl::aspect::fp16});
+    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
     if (freq_factors == nullptr) {
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
-                                    p_delta_rows, ext_factor, attn_factor,
-                                    corr_dims, theta_scale, freq_factors,
-                                    item_ct1);
-            });
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                                theta_scale, freq_factors, item_ct1);
+        });
     } else {
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
-                                    p_delta_rows, ext_factor, attn_factor,
-                                    corr_dims, theta_scale, freq_factors,
-                                    item_ct1);
-            });
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                               theta_scale, freq_factors, item_ct1);
+        });
     }
 }
 
@@ -272,7 +260,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
     }
 }
 
-void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
@@ -329,43 +317,46 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     if (is_neox) {
         GGML_SYCL_DEBUG("%s: neox path\n", __func__);
         if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_neox_sycl(
-                (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, main_stream
-            );
+            rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
+                           pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
         } else if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_neox_sycl(
-                (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, main_stream
-            );
+            rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
+                           n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                           main_stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_vision) {
         GGML_SYCL_DEBUG("%s: vision path\n", __func__);
         if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_vision_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
+            rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
+                             s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                             freq_factors, sections, main_stream);
         } else if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_vision_sycl((const float *) dst->src[0]->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
+            rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
+                             nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
+                             main_stream);
         } else {
             GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
     } else {
         GGML_SYCL_DEBUG("%s: norm path\n", __func__);
         if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_norm_sycl(
-                (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, main_stream
-            );
+            rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
+                           pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
         } else if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_norm_sycl(
-                (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, main_stream
-            );
+            rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
+                           n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                           main_stream);
         } else {
             GGML_ABORT("fatal error");
         }
     }
 }
+
+void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_rope(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
index a399bddb8a07b76003f8cab3e55bb58fccfb2cb9..8c7141aac5c9b1fed0d9ff573ec78f614b7a400e 100644 (file)
@@ -15,6 +15,6 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
+void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
 
 #endif // GGML_SYCL_ROPE_HPP