From: lhez Date: Mon, 3 Nov 2025 19:47:57 +0000 (-0800) Subject: opencl: support imrope (llama/16914) X-Git-Tag: upstream/0.9.4.185~35 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=3ea395a7786210d370797d1c779e6156f461b940;p=pkg%2Fggml%2Fsources%2Fggml opencl: support imrope (llama/16914) * opencl: support imrope * opencl: fix whitespace --- diff --git a/src/ggml-opencl/ggml-opencl.cpp b/src/ggml-opencl/ggml-opencl.cpp index 93a3600b..3dc4d035 100644 --- a/src/ggml-opencl/ggml-opencl.cpp +++ b/src/ggml-opencl/ggml-opencl.cpp @@ -8399,6 +8399,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const const bool is_neox = mode & 2; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + const int is_imrope = mode == GGML_ROPE_TYPE_IMROPE; if (is_mrope) { GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); @@ -8489,9 +8490,14 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor)); CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast)); CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow)); + // both mrope and vision kernels have sections if (is_mrope || is_vision) { CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, §ions)); } + // only mrope has is_imrope + if (is_mrope && !is_vision) { + CL_CHECK(clSetKernelArg(kernel, 34, sizeof(int), &is_imrope)); + } size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; diff --git a/src/ggml-opencl/kernels/rope.cl b/src/ggml-opencl/kernels/rope.cl index 0247730c..82f4cd87 100644 --- a/src/ggml-opencl/kernels/rope.cl +++ b/src/ggml-opencl/kernels/rope.cl @@ -392,7 +392,8 @@ kernel void kernel_rope_multi_f32( float attn_factor, float beta_fast, float beta_slow, - int4 sections + int4 sections, + int is_imrope ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -419,17 +420,29 @@ kernel void kernel_rope_multi_f32( const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0f; - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * sections.s1) { // h + theta_base = (float) pos[i2 + ne02 * 1]; + } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w + theta_base = (float) pos[i2 + ne02 * 2]; + } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t + theta_base = (float) pos[i2 + ne02 * 0]; + } else { // e + theta_base = (float) pos[i2 + ne02 * 3]; + } + } else { + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } } const float theta = theta_base * pow(freq_base, inv_ndims*i0); @@ -490,7 +503,8 @@ kernel void kernel_rope_multi_f16( float attn_factor, float beta_fast, float beta_slow, - int4 sections + int4 sections, + int is_imrope ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -517,17 +531,29 @@ kernel void kernel_rope_multi_f16( const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0f; - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * sections.s1) { // h + theta_base = (float) pos[i2 + ne02 * 1]; + } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w + theta_base = (float) pos[i2 + ne02 * 2]; + } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t + theta_base = (float) pos[i2 + ne02 * 0]; + } else { // e + theta_base = (float) pos[i2 + ne02 * 3]; + } + } else { + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } } const float theta = theta_base * pow(freq_base, inv_ndims*i0);