From: JJJYmmm Date: Thu, 30 Oct 2025 15:19:14 +0000 (+0800) Subject: model: add support for qwen3vl series (llama/16780) X-Git-Tag: upstream/0.9.4.185~56 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=fd33b4bb05d99c0784892a8f4e38d4c7a4d873e3;p=pkg%2Fggml%2Fsources%2Fggml model: add support for qwen3vl series (llama/16780) * support qwen3vl series. Co-authored-by: Thireus ☠ Co-authored-by: yairpatch Co-authored-by: LETS-BEE * bugfix: fix the arch check for qwen3vl-moe. * use build_ffn * optimize deepstack structure * optimize deepstack feature saving * Revert "optimize deepstack feature saving" for temporal fix This reverts commit f321b9fdf13e59527408152e73b1071e19a87e71. * code clean * use fused qkv in clip * clean up / rm is_deepstack_layers for simplification * add test model * move test model to "big" section * fix imrope check * remove trailing whitespace * fix rope fail * metal : add imrope support * add imrope support for sycl * vulkan: add imrope w/o check * fix vulkan * webgpu: add imrope w/o check * Update gguf-py/gguf/tensor_mapping.py Co-authored-by: Sigbjørn Skjæret * fix tensor mapping --------- Co-authored-by: Thireus ☠ Co-authored-by: yairpatch Co-authored-by: LETS-BEE Co-authored-by: Xuan Son Nguyen Co-authored-by: Georgi Gerganov Co-authored-by: Sigbjørn Skjæret --- diff --git a/include/ggml.h b/include/ggml.h index d948b00c..2311cdab 100644 --- a/include/ggml.h +++ b/include/ggml.h @@ -242,6 +242,7 @@ #define GGML_ROPE_TYPE_NEOX 2 #define GGML_ROPE_TYPE_MROPE 8 #define GGML_ROPE_TYPE_VISION 24 +#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000 #define GGML_MROPE_SECTIONS 4 diff --git a/src/ggml-cpu/ops.cpp b/src/ggml-cpu/ops.cpp index c17ab102..f66d36ff 100644 --- a/src/ggml-cpu/ops.cpp +++ b/src/ggml-cpu/ops.cpp @@ -5474,7 +5474,7 @@ static void ggml_rope_cache_init( } static void ggml_mrope_cache_init( - float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects, + float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, float * cache, float sin_sign, float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py @@ -5509,14 +5509,26 @@ static void ggml_mrope_cache_init( } float theta = theta_t; - if (sector >= sections[0] && sector < sec_w) { - theta = theta_h; - } - else if (sector >= sec_w && sector < sec_w + sections[2]) { - theta = theta_w; - } - else if (sector >= sec_w + sections[2]) { - theta = theta_e; + if (is_imrope) { // qwen3vl apply interleaved mrope + if (sector % 3 == 1 && sector < 3 * sections[1]) { + theta = theta_h; + } else if (sector % 3 == 2 && sector < 3 * sections[2]) { + theta = theta_w; + } else if (sector % 3 == 0 && sector < 3 * sections[0]) { + theta = theta_t; + } else { + theta = theta_e; + } + } else { + if (sector >= sections[0] && sector < sec_w) { + theta = theta_h; + } + else if (sector >= sec_w && sector < sec_w + sections[2]) { + theta = theta_w; + } + else if (sector >= sec_w + sections[2]) { + theta = theta_e; + } } rope_yarn( @@ -5589,6 +5601,7 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_mrope) { @@ -5627,7 +5640,7 @@ static void ggml_compute_forward_rope_f32( const int64_t p_w = pos[i2 + ne2 * 2]; const int64_t p_e = pos[i2 + ne2 * 3]; ggml_mrope_cache_init( - p_t, p_h, p_w, p_e, sections, is_vision, + p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); } @@ -5775,6 +5788,7 @@ static void ggml_compute_forward_rope_f16( const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_mrope) { @@ -5813,7 +5827,7 @@ static void ggml_compute_forward_rope_f16( const int64_t p_w = pos[i2 + ne2 * 2]; const int64_t p_e = pos[i2 + ne2 * 3]; ggml_mrope_cache_init( - p_t, p_h, p_w, p_e, sections, is_vision, + p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); } diff --git a/src/ggml-cuda/rope.cu b/src/ggml-cuda/rope.cu index d058504c..78ed7f51 100644 --- a/src/ggml-cuda/rope.cu +++ b/src/ggml-cuda/rope.cu @@ -125,7 +125,7 @@ template static __global__ void rope_multi( const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) { + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -152,17 +152,29 @@ static __global__ void rope_multi( const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0; - if (sector < sections.v[0]) { - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + } else { + theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + } + } else { + if (sector < sections.v[0]) { + theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + } } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -276,7 +288,7 @@ template static void rope_multi_cuda( const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -287,11 +299,11 @@ static void rope_multi_cuda( if (freq_factors == nullptr) { rope_multi<<>>( x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, - attn_factor, corr_dims, theta_scale, freq_factors, sections); + attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { rope_multi<<>>( x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, - attn_factor, corr_dims, theta_scale, freq_factors, sections); + attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } } @@ -369,6 +381,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_mrope) { @@ -406,11 +419,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) if (src0->type == GGML_TYPE_F32) { rope_multi_cuda( (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); } else if (src0->type == GGML_TYPE_F16) { rope_multi_cuda( (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); } else { GGML_ABORT("fatal error"); } diff --git a/src/ggml-metal/ggml-metal-device.cpp b/src/ggml-metal/ggml-metal-device.cpp index 75811634..1a3c7873 100644 --- a/src/ggml-metal/ggml-metal-device.cpp +++ b/src/ggml-metal/ggml-metal-device.cpp @@ -1332,11 +1332,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_neox) { snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type)); - } else if (is_mrope && !is_vision) { + } else if ((is_mrope || is_imrope) && !is_vision) { GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type)); } else if (is_vision) { @@ -1346,14 +1347,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type)); } - snprintf(name, 256, "%s", base); + snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); if (res) { return res; } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); return res; } diff --git a/src/ggml-metal/ggml-metal-impl.h b/src/ggml-metal/ggml-metal-impl.h index 96f43d26..7a878a65 100644 --- a/src/ggml-metal/ggml-metal-impl.h +++ b/src/ggml-metal/ggml-metal-impl.h @@ -76,6 +76,7 @@ #define FC_FLASH_ATTN_EXT_VEC_REDUCE 500 #define FC_MUL_MV 600 #define FC_MUL_MM 700 +#define FC_ROPE 800 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPTG 8 diff --git a/src/ggml-metal/ggml-metal.metal b/src/ggml-metal/ggml-metal.metal index 2c2f0141..fa839a1d 100644 --- a/src/ggml-metal/ggml-metal.metal +++ b/src/ggml-metal/ggml-metal.metal @@ -3709,6 +3709,8 @@ template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_ template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; #endif +constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]]; + static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); @@ -3889,14 +3891,26 @@ kernel void kernel_rope_multi( const int sector = ic % sect_dims; float theta_base; - if (sector < args.sect_0) { - theta_base = (float) pos[i2]; - } else if (sector < sec_w01) { - theta_base = (float) pos[i2 + args.ne02]; - } else if (sector < sec_w012) { - theta_base = (float) pos[i2 + args.ne02 * 2]; + if (FC_rope_is_imrope) { + if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h + theta_base = (float) pos[i2 + args.ne02 * 1]; + } else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w + theta_base = (float) pos[i2 + args.ne02 * 2]; + } else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t + theta_base = (float) pos[i2 + args.ne02 * 0]; + } else { // e + theta_base = (float) pos[i2 + args.ne02 * 3]; + } } else { - theta_base = (float) pos[i2 + args.ne02 * 3]; + if (sector < args.sect_0) { + theta_base = (float) pos[i2]; + } else if (sector < sec_w01) { + theta_base = (float) pos[i2 + args.ne02 * 1]; + } else if (sector < sec_w012) { + theta_base = (float) pos[i2 + args.ne02 * 2]; + } else { + theta_base = (float) pos[i2 + args.ne02 * 3]; + } } // end of mrope diff --git a/src/ggml-sycl/rope.cpp b/src/ggml-sycl/rope.cpp index a3ab703d..69140b19 100644 --- a/src/ggml-sycl/rope.cpp +++ b/src/ggml-sycl/rope.cpp @@ -119,7 +119,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, - const sycl::nd_item<3> & item_ct1) { + const bool is_imrope, const sycl::nd_item<3> & item_ct1) { // get index pos const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1)); if (i0 >= ne0) { @@ -143,17 +143,29 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const float theta_base = 0.0; - if (sector < sections.v[0]) { - theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f); - } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f); + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { + theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { + theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { + theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f); + } else { + theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f); + } + } else { + if (sector < sections.v[0]) { + theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f); + } } const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; @@ -281,7 +293,7 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const size_t s2, const int n_dims, const int nr, const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, - const mrope_sections sections, queue_ptr stream) { + const mrope_sections sections, const bool is_imrope, queue_ptr stream) { GGML_ASSERT(ne0 % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); @@ -297,12 +309,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, if (freq_factors == nullptr) { stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, sections, item_ct1); + corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1); }); } else { stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, sections, item_ct1); + corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1); }); } } @@ -381,6 +393,7 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_mrope) { @@ -422,11 +435,11 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) if (dst->src[0]->type == GGML_TYPE_F16) { rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, sections, main_stream); + freq_factors, sections, is_imrope, main_stream); } else if (dst->src[0]->type == GGML_TYPE_F32) { rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, - main_stream); + is_imrope, main_stream); } else { GGML_ABORT("Fatal error: Tensor type unsupported!"); } diff --git a/src/ggml-vulkan/ggml-vulkan.cpp b/src/ggml-vulkan/ggml-vulkan.cpp index d0976519..b61879aa 100644 --- a/src/ggml-vulkan/ggml-vulkan.cpp +++ b/src/ggml-vulkan/ggml-vulkan.cpp @@ -1056,6 +1056,7 @@ struct vk_op_rope_push_constants { uint32_t s1; uint32_t s2; int32_t sections[4]; + uint32_t is_imrope; uint32_t is_back; uint32_t set_rows_stride; }; @@ -9927,6 +9928,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); } + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; + float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); @@ -9948,7 +9951,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, - { sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride, + { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, }, dryrun); } diff --git a/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index 0eda186c..fa2bb333 100644 --- a/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +++ b/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -27,6 +27,7 @@ layout (push_constant) uniform parameter { uint s1; uint s2; int sections[4]; + uint is_imrope; uint is_back; uint set_rows_stride; } p; diff --git a/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 111286b4..54aabcf2 100644 --- a/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -32,17 +32,29 @@ void main() { const uint sector = (i0 / 2) % sect_dims; float theta_base = 0.0; - if (sector < p.sections[0]) { - theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= p.sections[0] && sector < sec_w) { - theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + if (p.is_imrope != 0) { + if (sector % 3 == 1 && sector < 3 * p.sections[1]) { + theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { + theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + } else { + theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + } else { + if (sector < p.sections[0]) { + theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= p.sections[0] && sector < sec_w) { + theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } } const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; diff --git a/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl index 9a6ff411..84dc8dbf 100644 --- a/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +++ b/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl @@ -221,6 +221,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let is_neox = bool(params.mode & 2); let is_mrope = bool(params.mode & 8); + let is_imrope = params.mode == 40; let is_vision = params.mode == 24; var i = gid.x * 2; // start index for this thread @@ -248,24 +249,36 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let sec_w = params.sections1 + params.sections0; let sec_e = params.sections2 + sec_w; let sector = (i0 / 2) % sect_dims; - if (sector >= params.sections0 && sector < sec_w) { - theta_base_mult = 1; - if (is_vision) { - theta_scale_pwr = sector - params.sections0; - } - } else if (sector >= sec_w && sector < sec_e) { - theta_base_mult = 2; - if (is_vision) { - theta_scale_pwr = sector - sec_w; - } - } else if (sector >= sec_e) { - if (is_vision) { - theta_scale_pwr = sector - sec_e; - theta_scale_pwr = (i0 / 2) % sec_e; - } - theta_base_mult = 3; - } else if (is_vision) { - theta_scale_pwr = sector; + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * params.sections1) { + theta_base_mult = 1; + } else if (sector % 3 == 2 && sector < 3 * params.sections2) { + theta_base_mult = 2; + } else if (sector % 3 == 0 && sector < 3 * params.sections0) { + theta_base_mult = 0; + } else { + theta_base_mult = 3; + } + } else { + if (sector >= params.sections0 && sector < sec_w) { + theta_base_mult = 1; + if (is_vision) { + theta_scale_pwr = sector - params.sections0; + } + } else if (sector >= sec_w && sector < sec_e) { + theta_base_mult = 2; + if (is_vision) { + theta_scale_pwr = sector - sec_w; + } + } else if (sector >= sec_e) { + if (is_vision) { + theta_scale_pwr = sector - sec_e; + theta_scale_pwr = (i0 / 2) % sec_e; + } + theta_base_mult = 3; + } else if (is_vision) { + theta_scale_pwr = sector; + } } } let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr)); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4b1304c2..92361d6f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7076,7 +7076,12 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B) test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B) + test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 7B) + test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) + test_cases.emplace_back(new test_rope(type, {128, 16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl) } test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) @@ -7092,7 +7097,7 @@ static std::vector> make_test_cases_eval() { // single inplace test per type/mode/ff for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION}) { + for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_IMROPE, GGML_ROPE_TYPE_VISION}) { for (bool ff : {false, true}) { test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true)); }