]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cuda : fix rope with partial rotation and non-cont src (llama/14580)
authorGeorgi Gerganov <redacted>
Tue, 8 Jul 2025 07:15:21 +0000 (10:15 +0300)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 16:23:56 +0000 (19:23 +0300)
* cuda : fix rope non-cont

ggml-ci

* cont : fix multi-rope + add test

ggml-ci

* sycl : try fix

ggml-ci

* cont : fix sycl + clean-up cuda

ggml-ci

ggml/src/ggml-cuda/rope.cu
ggml/src/ggml-sycl/rope.cpp

index 18f691b2d3103fce13cd5cd60f522798157d91e6..d058504cd6cc036d690cb2cfea00e36f8e80630c 100644 (file)
@@ -50,21 +50,19 @@ static __global__ void rope_norm(
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
     const int idst = row_dst*ne0 + i0;
     const int ix   = channel_x*s2 + row_x*s1 + i0;
 
+    if (i0 >= n_dims) {
+        dst[idst + 0] = x[ix + 0];
+        dst[idst + 1] = x[ix + 1];
+
+        return;
+    }
+
     const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -94,21 +92,19 @@ static __global__ void rope_neox(
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
     const int idst = row_dst*ne0 + i0/2;
     const int ix   = channel_x*s2 + row_x*s1 + i0/2;
 
+    if (i0 >= n_dims) {
+        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
+        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+
+        return;
+    }
+
     const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -138,21 +134,19 @@ static __global__ void rope_multi(
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
     const int idst = row_dst*ne0 + i0/2;
     const int ix   = channel_x*s2 + row_x*s1 + i0/2;
 
+    if (i0 >= n_dims) {
+        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
+        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+
+        return;
+    }
+
     const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
     const int sec_w = sections.v[1] + sections.v[0];
     const int sector = (i0 / 2) % sect_dims;
index e44c6b6ef8f42bd0d867b1e416b25ad94d0302ac..1b60226dcd531ed798cc1499c593260833e6c096 100644 (file)
@@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
 
     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
-    if (i0 >= n_dims) {
-        const int i = row * ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int row0     = row % ne1;
     const int channel0 = row / ne1;
 
     const int i  = row * ne0 + i0;
     const int i2 = channel0 * s2 + row0 * s1 + i0;
 
+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
+        return;
+    }
+
     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
 
     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
-    if (i0 >= n_dims) {
-        const int i = row * ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int row0     = row % ne1;
     const int channel0 = row / ne1;
 
     const int i  = row * ne0 + i0 / 2;
     const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
 
+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
+        return;
+    }
+
     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
     }
     const int    row_dst   = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
 
-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int    row_x     = row_dst % ne1;
     const int    channel_x = row_dst / ne1;
     const int    idst      = (row_dst * ne0) + (i0 / 2);
     const size_t ix        = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
 
+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
+        return;
+    }
+
     const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
     const int sec_w = sections.v[1] + sections.v[0];
     const int sector = (i0 / 2) % sect_dims;