]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
opencl: support imrope (llama/16914)
authorlhez <redacted>
Mon, 3 Nov 2025 19:47:57 +0000 (11:47 -0800)
committerGeorgi Gerganov <redacted>
Sun, 9 Nov 2025 16:30:22 +0000 (18:30 +0200)
* opencl: support imrope

* opencl: fix whitespace

src/ggml-opencl/ggml-opencl.cpp
src/ggml-opencl/kernels/rope.cl

index 93a3600b63f07f2bf8c0e5df0679b1bccc962d12..3dc4d03550931367b44834762b64258a93b0a533 100644 (file)
@@ -8399,6 +8399,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
     const bool is_neox = mode & 2;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+    const int  is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
 
     if (is_mrope) {
         GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
@@ -8489,9 +8490,14 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
     CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float),    &attn_factor));
     CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float),    &beta_fast));
     CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float),    &beta_slow));
+    // both mrope and vision kernels have sections
     if (is_mrope || is_vision) {
         CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, &sections));
     }
+    // only mrope has is_imrope
+    if (is_mrope && !is_vision) {
+        CL_CHECK(clSetKernelArg(kernel, 34, sizeof(int), &is_imrope));
+    }
 
     size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)nth, 1, 1};
index 0247730c0365f71b28833d51ad83aa34d22f4e13..82f4cd87407d7d1436e10c485943a974fa0b5a5b 100644 (file)
@@ -392,7 +392,8 @@ kernel void kernel_rope_multi_f32(
         float attn_factor,
         float beta_fast,
         float beta_slow,
-        int4 sections
+        int4 sections,
+        int  is_imrope
 ) {
     src0 = (global void*)((global char*)src0 + offset0);
     src1 = (global int*)((global char*)src1 + offset1);
@@ -419,17 +420,29 @@ kernel void kernel_rope_multi_f32(
             const int sector = (i0 / 2) % sect_dims;
             float theta_base = 0.0f;
 
-            if (sector < sections.s0) {
-                theta_base = pos[i2];
-            }
-            else if (sector >= sections.s0 && sector < sec_w) {
-                theta_base = pos[i2 + ne2 * 1];
-            }
-            else if (sector >= sec_w && sector < sec_w + sections.s2) {
-                theta_base = pos[i2 + ne2 * 2];
-            }
-            else if (sector >= sec_w + sections.s2) {
-                theta_base = pos[i2 + ne2 * 3];
+            if (is_imrope) {
+                if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
+                    theta_base = (float) pos[i2 + ne02 * 1];
+                } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
+                    theta_base = (float) pos[i2 + ne02 * 2];
+                } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
+                    theta_base = (float) pos[i2 + ne02 * 0];
+                } else { // e
+                    theta_base = (float) pos[i2 + ne02 * 3];
+                }
+            } else {
+                if (sector < sections.s0) {
+                    theta_base = pos[i2];
+                }
+                else if (sector >= sections.s0 && sector < sec_w) {
+                    theta_base = pos[i2 + ne2 * 1];
+                }
+                else if (sector >= sec_w && sector < sec_w + sections.s2) {
+                    theta_base = pos[i2 + ne2 * 2];
+                }
+                else if (sector >= sec_w + sections.s2) {
+                    theta_base = pos[i2 + ne2 * 3];
+                }
             }
 
             const float theta = theta_base * pow(freq_base, inv_ndims*i0);
@@ -490,7 +503,8 @@ kernel void kernel_rope_multi_f16(
         float attn_factor,
         float beta_fast,
         float beta_slow,
-        int4 sections
+        int4 sections,
+        int  is_imrope
 ) {
     src0 = (global void*)((global char*)src0 + offset0);
     src1 = (global int*)((global char*)src1 + offset1);
@@ -517,17 +531,29 @@ kernel void kernel_rope_multi_f16(
             const int sector = (i0 / 2) % sect_dims;
             float theta_base = 0.0f;
 
-            if (sector < sections.s0) {
-                theta_base = pos[i2];
-            }
-            else if (sector >= sections.s0 && sector < sec_w) {
-                theta_base = pos[i2 + ne2 * 1];
-            }
-            else if (sector >= sec_w && sector < sec_w + sections.s2) {
-                theta_base = pos[i2 + ne2 * 2];
-            }
-            else if (sector >= sec_w + sections.s2) {
-                theta_base = pos[i2 + ne2 * 3];
+            if (is_imrope) {
+                if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
+                    theta_base = (float) pos[i2 + ne02 * 1];
+                } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
+                    theta_base = (float) pos[i2 + ne02 * 2];
+                } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
+                    theta_base = (float) pos[i2 + ne02 * 0];
+                } else { // e
+                    theta_base = (float) pos[i2 + ne02 * 3];
+                }
+            } else {
+                if (sector < sections.s0) {
+                    theta_base = pos[i2];
+                }
+                else if (sector >= sections.s0 && sector < sec_w) {
+                    theta_base = pos[i2 + ne2 * 1];
+                }
+                else if (sector >= sec_w && sector < sec_w + sections.s2) {
+                    theta_base = pos[i2 + ne2 * 2];
+                }
+                else if (sector >= sec_w + sections.s2) {
+                    theta_base = pos[i2 + ne2 * 3];
+                }
             }
 
             const float theta = theta_base * pow(freq_base, inv_ndims*i0);