]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
model: add support for qwen3vl series (llama/16780)
authorJJJYmmm <redacted>
Thu, 30 Oct 2025 15:19:14 +0000 (23:19 +0800)
committerGeorgi Gerganov <redacted>
Sat, 1 Nov 2025 07:41:35 +0000 (09:41 +0200)
* support qwen3vl series.

Co-authored-by: Thireus ☠ <redacted>
Co-authored-by: yairpatch <redacted>
Co-authored-by: LETS-BEE <redacted>
* bugfix: fix the arch check for qwen3vl-moe.

* use build_ffn

* optimize deepstack structure

* optimize deepstack feature saving

* Revert "optimize deepstack feature saving" for temporal fix

This reverts commit f321b9fdf13e59527408152e73b1071e19a87e71.

* code clean

* use fused qkv in clip

* clean up / rm is_deepstack_layers for simplification

* add test model

* move test model to "big" section

* fix imrope check

* remove trailing whitespace

* fix rope fail

* metal : add imrope support

* add imrope support for sycl

* vulkan: add imrope w/o check

* fix vulkan

* webgpu: add imrope w/o check

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* fix tensor mapping

---------

Co-authored-by: Thireus ☠ <redacted>
Co-authored-by: yairpatch <redacted>
Co-authored-by: LETS-BEE <redacted>
Co-authored-by: Xuan Son Nguyen <redacted>
Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: Sigbjørn Skjæret <redacted>
12 files changed:
include/ggml.h
src/ggml-cpu/ops.cpp
src/ggml-cuda/rope.cu
src/ggml-metal/ggml-metal-device.cpp
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal.metal
src/ggml-sycl/rope.cpp
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/rope_head.glsl
src/ggml-vulkan/vulkan-shaders/rope_multi.comp
src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl
tests/test-backend-ops.cpp

index d948b00cc7f3048258b358ff32410d2e20204fb6..2311cdabe3ba4ad85009f688b5fa1ccbb160e924 100644 (file)
 #define GGML_ROPE_TYPE_NEOX   2
 #define GGML_ROPE_TYPE_MROPE  8
 #define GGML_ROPE_TYPE_VISION 24
+#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
 
 #define GGML_MROPE_SECTIONS   4
 
index c17ab10245d58240d11e114aed31fb359dc6236c..f66d36ff62c035313027129217988ce0a7450553 100644 (file)
@@ -5474,7 +5474,7 @@ static void ggml_rope_cache_init(
 }
 
 static void ggml_mrope_cache_init(
-     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
      float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
      float * cache, float sin_sign, float theta_scale) {
     // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -5509,14 +5509,26 @@ static void ggml_mrope_cache_init(
         }
 
         float theta = theta_t;
-        if (sector >= sections[0] && sector < sec_w) {
-            theta = theta_h;
-        }
-        else if (sector >= sec_w && sector < sec_w + sections[2]) {
-            theta = theta_w;
-        }
-        else if (sector >= sec_w + sections[2]) {
-            theta = theta_e;
+        if (is_imrope) { // qwen3vl apply interleaved mrope
+            if (sector % 3 == 1 && sector < 3 * sections[1]) {
+                theta = theta_h;
+            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+                theta = theta_w;
+            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
+                theta = theta_t;
+            } else {
+                theta = theta_e;
+            }
+        } else {
+            if (sector >= sections[0] && sector < sec_w) {
+                theta = theta_h;
+            }
+            else if (sector >= sec_w && sector < sec_w + sections[2]) {
+                theta = theta_w;
+            }
+            else if (sector >= sec_w + sections[2]) {
+                theta = theta_e;
+            }
         }
 
         rope_yarn(
@@ -5589,6 +5601,7 @@ static void ggml_compute_forward_rope_f32(
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, multimodal rotary position embedding
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
@@ -5627,7 +5640,7 @@ static void ggml_compute_forward_rope_f32(
                 const int64_t p_w = pos[i2 + ne2 * 2];
                 const int64_t p_e = pos[i2 + ne2 * 3];
                 ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_vision,
+                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
                     freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
             }
 
@@ -5775,6 +5788,7 @@ static void ggml_compute_forward_rope_f16(
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
@@ -5813,7 +5827,7 @@ static void ggml_compute_forward_rope_f16(
                 const int64_t p_w = pos[i2 + ne2 * 2];
                 const int64_t p_e = pos[i2 + ne2 * 3];
                 ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_vision,
+                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
                     freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
             }
 
index d058504cd6cc036d690cb2cfea00e36f8e80630c..78ed7f519abb9bacc6ed022c8b31a8cd38f2b5c6 100644 (file)
@@ -125,7 +125,7 @@ template<bool forward, bool has_ff, typename T>
 static __global__ void rope_multi(
         const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
         const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -152,17 +152,29 @@ static __global__ void rope_multi(
     const int sector = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
-    if (sector < sections.v[0]) {
-        theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
-    }
-    else if (sector >= sections.v[0] && sector < sec_w) {
-        theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
-    }
-    else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
-        theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
-    }
-    else if (sector >= sec_w + sections.v[2]) {
-        theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+    if (is_imrope) {
+        if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
+            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
+        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
+            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
+            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+        } else {
+            theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+        }
+    } else {
+        if (sector < sections.v[0]) {
+            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+        }
+        else if (sector >= sections.v[0] && sector < sec_w) {
+            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
+        }
+        else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
+        }
+        else if (sector >= sec_w + sections.v[2]) {
+            theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+        }
     }
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -276,7 +288,7 @@ template<bool forward, typename T>
 static void rope_multi_cuda(
         const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
         const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
+        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -287,11 +299,11 @@ static void rope_multi_cuda(
     if (freq_factors == nullptr) {
         rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
             x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors, sections);
+            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
     } else {
         rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
             x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors, sections);
+            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
     }
 }
 
@@ -369,6 +381,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
@@ -406,11 +419,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         if (src0->type == GGML_TYPE_F32) {
             rope_multi_cuda<forward>(
                 (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
         } else if (src0->type == GGML_TYPE_F16) {
             rope_multi_cuda<forward>(
                 (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
         } else {
             GGML_ABORT("fatal error");
         }
index 75811634227b352987f8b75f37836d52ea79f8cd..1a3c7873b745c4eaf75cca18647c48760ac9ec39 100644 (file)
@@ -1332,11 +1332,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
 
     const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_neox) {
         snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
-    } else if (is_mrope && !is_vision) {
+    } else if ((is_mrope || is_imrope) && !is_vision) {
         GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
         snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
     } else if (is_vision) {
@@ -1346,14 +1347,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
         snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
     }
 
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
         return res;
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+    ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
 
     return res;
 }
index 96f43d260a3c33dc539b4a6e16f259f3b66e2bf3..7a878a657bc124478d299c4a6a0f7862c3e68174 100644 (file)
@@ -76,6 +76,7 @@
 #define FC_FLASH_ATTN_EXT_VEC_REDUCE   500
 #define FC_MUL_MV                      600
 #define FC_MUL_MM                      700
+#define FC_ROPE                        800
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPTG 8
index 2c2f0141514ca252a86f8b02818f90620015d887..fa839a1df6e304c8f85adb7e392e9fbd7c1d8d37 100644 (file)
@@ -3709,6 +3709,8 @@ template [[host_name("kernel_mul_mv_bf16_f32_short")]]  kernel mul_mv_t_t_short_
 template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
 #endif
 
+constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
+
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
@@ -3889,14 +3891,26 @@ kernel void kernel_rope_multi(
             const int sector    = ic % sect_dims;
 
             float theta_base;
-            if (sector < args.sect_0) {
-                theta_base = (float) pos[i2];
-            } else if (sector < sec_w01) {
-                theta_base = (float) pos[i2 + args.ne02];
-            } else if (sector < sec_w012) {
-                theta_base = (float) pos[i2 + args.ne02 * 2];
+            if (FC_rope_is_imrope) {
+                if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
+                    theta_base = (float) pos[i2 + args.ne02 * 1];
+                } else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
+                    theta_base = (float) pos[i2 + args.ne02 * 2];
+                } else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
+                    theta_base = (float) pos[i2 + args.ne02 * 0];
+                } else { // e
+                    theta_base = (float) pos[i2 + args.ne02 * 3];
+                }
             } else {
-                theta_base = (float) pos[i2 + args.ne02 * 3];
+                if (sector < args.sect_0) {
+                    theta_base = (float) pos[i2];
+                } else if (sector < sec_w01) {
+                    theta_base = (float) pos[i2 + args.ne02 * 1];
+                } else if (sector < sec_w012) {
+                    theta_base = (float) pos[i2 + args.ne02 * 2];
+                } else {
+                    theta_base = (float) pos[i2 + args.ne02 * 3];
+                }
             }
             // end of mrope
 
index a3ab703d1f08896ab33f05fa21ba9518d2fae31c..69140b19a4c072c0eb2eb20114aea5b1fa35910a 100644 (file)
@@ -119,7 +119,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
                         const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
                         const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
                         const float theta_scale, const float * freq_factors, const mrope_sections sections,
-                        const sycl::nd_item<3> & item_ct1) {
+                        const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
     // get index pos
     const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
     if (i0 >= ne0) {
@@ -143,17 +143,29 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
 
 
     float theta_base = 0.0;
-    if (sector < sections.v[0]) {
-        theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
-    }
-    else if (sector >= sections.v[0] && sector < sec_w) {
-        theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
-    }
-    else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
-        theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
-    }
-    else if (sector >= sec_w + sections.v[2]) {
-        theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+    if (is_imrope) {
+        if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
+            theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
+        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
+            theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
+            theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
+        } else {
+            theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+        }
+    } else {
+        if (sector < sections.v[0]) {
+            theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
+        }
+        else if (sector >= sections.v[0] && sector < sec_w) {
+            theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
+        }
+        else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+            theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
+        }
+        else if (sector >= sec_w + sections.v[2]) {
+            theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+        }
     }
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -281,7 +293,7 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
                              const size_t s2, const int n_dims, const int nr, const int32_t * pos,
                              const float freq_scale, const float freq_base, const float ext_factor,
                              const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
-                             const mrope_sections sections, queue_ptr stream) {
+                             const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3>    block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
     const int               n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
@@ -297,12 +309,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
     if (freq_factors == nullptr) {
         stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
+                                  corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
         });
     } else {
         stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                 corr_dims, theta_scale, freq_factors, sections, item_ct1);
+                                 corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
         });
     }
 }
@@ -381,6 +393,7 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
@@ -422,11 +435,11 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
         if (dst->src[0]->type == GGML_TYPE_F16) {
             rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
                 s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                freq_factors, sections, main_stream);
+                freq_factors, sections, is_imrope, main_stream);
         } else if (dst->src[0]->type == GGML_TYPE_F32) {
             rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
                              nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
-                             main_stream);
+                             is_imrope, main_stream);
         } else {
             GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
index d0976519f263feb5b72e6d69658eab77c6ec580e..b61879aa5d312351859b091705158f6c73bd0f7f 100644 (file)
@@ -1056,6 +1056,7 @@ struct vk_op_rope_push_constants {
     uint32_t s1;
     uint32_t s2;
     int32_t sections[4];
+    uint32_t is_imrope;
     uint32_t is_back;
     uint32_t set_rows_stride;
 };
@@ -9927,6 +9928,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
     }
 
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
@@ -9948,7 +9951,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
         src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
-        { sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride,
+        { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
     }, dryrun);
 }
 
index 0eda186c8a37c8ab3bde645526cf2b70a06aead1..fa2bb33394cb2df081c31d1723cce3a23e212740 100644 (file)
@@ -27,6 +27,7 @@ layout (push_constant) uniform parameter {
     uint s1;
     uint s2;
     int sections[4];
+    uint is_imrope;
     uint is_back;
     uint set_rows_stride;
 } p;
index 111286b4988c3f76aa2a5eabfa20eb31802b8483..54aabcf22283893470cd75c56fda2245f57a7d24 100644 (file)
@@ -32,17 +32,29 @@ void main() {
     const uint sector = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
-    if (sector < p.sections[0]) {
-        theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
-    }
-    else if (sector >= p.sections[0] && sector < sec_w) {
-        theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
-    }
-    else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
-        theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
-    }
-    else if (sector >= sec_w + p.sections[2]) {
-        theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+    if (p.is_imrope != 0) {
+        if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
+            theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+        } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
+            theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
+            theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
+        } else {
+            theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+        }
+    } else {
+        if (sector < p.sections[0]) {
+            theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
+        }
+        else if (sector >= p.sections[0] && sector < sec_w) {
+            theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+        }
+        else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
+            theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+        }
+        else if (sector >= sec_w + p.sections[2]) {
+            theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+        }
     }
 
     const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
index 9a6ff41128b6d7815da555f070d003e74170d3f0..84dc8dbff61debb250af27dd59b617412f9c1026 100644 (file)
@@ -221,6 +221,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 
     let is_neox = bool(params.mode & 2);
     let is_mrope = bool(params.mode & 8);
+    let is_imrope = params.mode == 40;
     let is_vision = params.mode == 24;
 
     var i = gid.x * 2; // start index for this thread
@@ -248,24 +249,36 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
         let sec_w = params.sections1 + params.sections0;
         let sec_e = params.sections2 + sec_w;
         let sector = (i0 / 2) % sect_dims;
-        if (sector >= params.sections0 && sector < sec_w) {
-            theta_base_mult = 1;
-            if (is_vision) {
-                theta_scale_pwr = sector - params.sections0;
-            }
-        } else if (sector >= sec_w && sector < sec_e) {
-            theta_base_mult = 2;
-            if (is_vision) {
-                theta_scale_pwr = sector - sec_w;
-            }
-        } else if (sector >= sec_e) {
-            if (is_vision) {
-                theta_scale_pwr = sector - sec_e;
-                theta_scale_pwr = (i0 / 2) % sec_e;
-            }
-            theta_base_mult = 3;
-        } else if (is_vision) {
-            theta_scale_pwr = sector;
+        if (is_imrope) {
+          if (sector % 3 == 1 && sector < 3 * params.sections1) {
+              theta_base_mult = 1;
+          } else if (sector % 3 == 2 && sector < 3 * params.sections2) {
+              theta_base_mult = 2;
+          } else if (sector % 3 == 0 && sector < 3 * params.sections0) {
+              theta_base_mult = 0;
+          } else {
+              theta_base_mult = 3;
+          }
+        } else {
+          if (sector >= params.sections0 && sector < sec_w) {
+              theta_base_mult = 1;
+              if (is_vision) {
+                  theta_scale_pwr = sector - params.sections0;
+              }
+          } else if (sector >= sec_w && sector < sec_e) {
+              theta_base_mult = 2;
+              if (is_vision) {
+                  theta_scale_pwr = sector - sec_w;
+              }
+          } else if (sector >= sec_e) {
+              if (is_vision) {
+                  theta_scale_pwr = sector - sec_e;
+                  theta_scale_pwr = (i0 / 2) % sec_e;
+              }
+              theta_base_mult = 3;
+          } else if (is_vision) {
+              theta_scale_pwr = sector;
+          }
         }
     }
     let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
index 4b1304c2b891cc549656759d4c8753d6148a3b84..92361d6f0f4d7763c19fe298cf55dbd5957db3db 100644 (file)
@@ -7076,7 +7076,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                     test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
                                     test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1},  20, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
                                     test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1},  32, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
+                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)
+                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 7B)
+                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1},  20, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw));
+                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1},  32, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw));
                                     test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
+                                    test_cases.emplace_back(new test_rope(type, {128,  16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)
                                 }
 
                                 test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
@@ -7092,7 +7097,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
     // single inplace test per type/mode/ff
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION}) {
+        for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_IMROPE, GGML_ROPE_TYPE_VISION}) {
             for (bool ff : {false, true}) {
                 test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));
             }