]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : fix rope with partial rotation and non-cont src (llama/14580)
authorGeorgi Gerganov <redacted>
Tue, 8 Jul 2025 07:15:21 +0000 (10:15 +0300)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
* cuda : fix rope non-cont

ggml-ci

* cont : fix multi-rope + add test

ggml-ci

* sycl : try fix

ggml-ci

* cont : fix sycl + clean-up cuda

ggml-ci

src/ggml-cuda/rope.cu
src/ggml-sycl/rope.cpp
tests/test-backend-ops.cpp

index 18f691b2d3103fce13cd5cd60f522798157d91e6..d058504cd6cc036d690cb2cfea00e36f8e80630c 100644 (file)
@@ -50,21 +50,19 @@ static __global__ void rope_norm(
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
     const int idst = row_dst*ne0 + i0;
     const int ix   = channel_x*s2 + row_x*s1 + i0;
 
+    if (i0 >= n_dims) {
+        dst[idst + 0] = x[ix + 0];
+        dst[idst + 1] = x[ix + 1];
+
+        return;
+    }
+
     const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -94,21 +92,19 @@ static __global__ void rope_neox(
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
     const int idst = row_dst*ne0 + i0/2;
     const int ix   = channel_x*s2 + row_x*s1 + i0/2;
 
+    if (i0 >= n_dims) {
+        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
+        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+
+        return;
+    }
+
     const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -138,21 +134,19 @@ static __global__ void rope_multi(
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
     const int idst = row_dst*ne0 + i0/2;
     const int ix   = channel_x*s2 + row_x*s1 + i0/2;
 
+    if (i0 >= n_dims) {
+        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
+        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+
+        return;
+    }
+
     const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
     const int sec_w = sections.v[1] + sections.v[0];
     const int sector = (i0 / 2) % sect_dims;
index e44c6b6ef8f42bd0d867b1e416b25ad94d0302ac..1b60226dcd531ed798cc1499c593260833e6c096 100644 (file)
@@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
 
     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
-    if (i0 >= n_dims) {
-        const int i = row * ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int row0     = row % ne1;
     const int channel0 = row / ne1;
 
     const int i  = row * ne0 + i0;
     const int i2 = channel0 * s2 + row0 * s1 + i0;
 
+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
+        return;
+    }
+
     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
 
     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
-    if (i0 >= n_dims) {
-        const int i = row * ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int row0     = row % ne1;
     const int channel0 = row / ne1;
 
     const int i  = row * ne0 + i0 / 2;
     const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
 
+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
+        return;
+    }
+
     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
     }
     const int    row_dst   = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
 
-    if (i0 >= n_dims) {
-        const int i = row_dst*ne0 + i0;
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
-        return;
-    }
-
     const int    row_x     = row_dst % ne1;
     const int    channel_x = row_dst / ne1;
     const int    idst      = (row_dst * ne0) + (i0 / 2);
     const size_t ix        = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
 
+    if (i0 >= n_dims) {
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
+        return;
+    }
+
     const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
     const int sec_w = sections.v[1] + sections.v[0];
     const int sector = (i0 / 2) % sect_dims;
index 652856a35d5e7b96a56ad964054b5c607fc79c37..b54bcc8a35e640b27b086bf340cb927b38c0675b 100644 (file)
@@ -5323,12 +5323,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (bool fw : {true, false}) { // fw == forward
         bool all = true;
 
-        for (float v : { 0, 1 }) {
-            for (float fs : { 1.0f, 1.4245f }) {
-                for (float ef : { 0.0f, 0.7465f }) {
-                    for (float af : { 1.0f, 1.4245f }) {
-                        for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                            for (bool ff : {false, true}) { // freq_factors
+        for (float fs : { 1.0f, 1.4245f }) {
+            for (float ef : { 0.0f, 0.7465f }) {
+                for (float af : { 1.0f, 1.4245f }) {
+                    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                        for (bool ff : {false, true}) { // freq_factors
+                            for (float v : { 0, 1 }) {
                                 test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
 
                                 if (all) {
@@ -5341,13 +5341,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                     test_cases.emplace_back(new test_rope(type, { 64,   1, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
                                     test_cases.emplace_back(new test_rope(type, { 64,  71, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
                                     test_cases.emplace_back(new test_rope(type, { 64,   8, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, 0, 512, fs, ef, af, ff, v, fw));
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, 0, 512, fs, ef, af, ff, v, fw));
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, 0, 512, fs, ef, af, ff, v, fw));
+
                                     test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
                                     test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                                 }
 
                                 if (all) {
                                     test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
                                     test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
+                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1},  20, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
+                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1},  32, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
                                     test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
                                 }