#define GGML_ROPE_TYPE_NEOX 2
#define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 24
+#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000 = 32 | GGML_ROPE_TYPE_MROPE (8); a (mode & GGML_ROPE_TYPE_MROPE) test is therefore also true for IMROPE
#define GGML_MROPE_SECTIONS 4
}
static void ggml_mrope_cache_init(
- float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+ float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
}
float theta = theta_t;
- if (sector >= sections[0] && sector < sec_w) {
- theta = theta_h;
- }
- else if (sector >= sec_w && sector < sec_w + sections[2]) {
- theta = theta_w;
- }
- else if (sector >= sec_w + sections[2]) {
- theta = theta_e;
+ if (is_imrope) { // Qwen3-VL interleaved M-RoPE: the component is picked by sector % 3 instead of by contiguous ranges
+ if (sector % 3 == 1 && sector < 3 * sections[1]) { // h
+ theta = theta_h;
+ } else if (sector % 3 == 2 && sector < 3 * sections[2]) { // w
+ theta = theta_w;
+ } else if (sector % 3 == 0 && sector < 3 * sections[0]) { // t
+ theta = theta_t;
+ } else { // e: every sector not claimed by the interleaved t/h/w tests above
+ theta = theta_e;
+ }
+ } else { // non-interleaved M-RoPE: contiguous sector ranges; theta keeps its theta_t initial value when no branch matches
+ if (sector >= sections[0] && sector < sec_w) { // h
+ theta = theta_h;
+ }
+ else if (sector >= sec_w && sector < sec_w + sections[2]) { // w
+ theta = theta_w;
+ }
+ else if (sector >= sec_w + sections[2]) { // e
+ theta = theta_e;
+ }
}
rope_yarn(
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
- p_t, p_h, p_w, p_e, sections, is_vision,
+ p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
- p_t, p_h, p_w, p_e, sections, is_vision,
+ p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
static __global__ void rope_multi(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
- const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
- if (sector < sections.v[0]) {
- theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
- }
- else if (sector >= sections.v[0] && sector < sec_w) {
- theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
- }
- else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
- theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
- }
- else if (sector >= sec_w + sections.v[2]) {
- theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+ if (is_imrope) { // interleaved M-RoPE: the pos component is picked by sector % 3
+ if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
+ theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
+ } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
+ theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
+ } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
+ theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+ } else { // e: every sector not claimed by the t/h/w tests above
+ theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+ }
+ } else { // non-interleaved M-RoPE: contiguous sector ranges
+ if (sector < sections.v[0]) { // t
+ theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+ }
+ else if (sector >= sections.v[0] && sector < sec_w) { // h
+ theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
+ }
+ else if (sector >= sec_w && sector < sec_w + sections.v[2]) { // w
+ theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
+ }
+ else if (sector >= sec_w + sections.v[2]) { // e
+ theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+ }
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
static void rope_multi_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
- const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
+ const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
if (freq_factors == nullptr) {
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
- attn_factor, corr_dims, theta_scale, freq_factors, sections);
+ attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
} else {
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
- attn_factor, corr_dims, theta_scale, freq_factors, sections);
+ attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
}
}
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
if (src0->type == GGML_TYPE_F32) {
rope_multi_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else {
GGML_ABORT("fatal error");
}
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_neox) {
snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
- } else if (is_mrope && !is_vision) {
+ } else if ((is_mrope || is_imrope) && !is_vision) {
GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
} else if (is_vision) {
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
}
- snprintf(name, 256, "%s", base);
+ snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+ ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
+
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+ ggml_metal_cv_free(cv);
return res;
}
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
#define FC_MUL_MV 600
#define FC_MUL_MM 700
+#define FC_ROPE 800
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
#endif
+constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
+
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
const int sector = ic % sect_dims;
float theta_base;
- if (sector < args.sect_0) {
- theta_base = (float) pos[i2];
- } else if (sector < sec_w01) {
- theta_base = (float) pos[i2 + args.ne02];
- } else if (sector < sec_w012) {
- theta_base = (float) pos[i2 + args.ne02 * 2];
+ if (FC_rope_is_imrope) {
+ if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
+ theta_base = (float) pos[i2 + args.ne02 * 1];
+ } else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
+ theta_base = (float) pos[i2 + args.ne02 * 2];
+ } else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
+ theta_base = (float) pos[i2 + args.ne02 * 0];
+ } else { // e
+ theta_base = (float) pos[i2 + args.ne02 * 3];
+ }
} else {
- theta_base = (float) pos[i2 + args.ne02 * 3];
+ if (sector < args.sect_0) {
+ theta_base = (float) pos[i2];
+ } else if (sector < sec_w01) {
+ theta_base = (float) pos[i2 + args.ne02 * 1];
+ } else if (sector < sec_w012) {
+ theta_base = (float) pos[i2 + args.ne02 * 2];
+ } else {
+ theta_base = (float) pos[i2 + args.ne02 * 3];
+ }
}
// end of mrope
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections,
- const sycl::nd_item<3> & item_ct1) {
+ const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
// get index pos
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
if (i0 >= ne0) {
float theta_base = 0.0;
- if (sector < sections.v[0]) {
- theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
- }
- else if (sector >= sections.v[0] && sector < sec_w) {
- theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
- }
- else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
- theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
- }
- else if (sector >= sec_w + sections.v[2]) {
- theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+ if (is_imrope) { // interleaved M-RoPE: the pos component is picked by sector % 3
+ if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
+ theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
+ } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
+ theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
+ } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
+ theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
+ } else { // e: every sector not claimed by the t/h/w tests above
+ theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+ }
+ } else { // non-interleaved M-RoPE: contiguous sector ranges
+ if (sector < sections.v[0]) { // t
+ theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
+ }
+ else if (sector >= sections.v[0] && sector < sec_w) { // h
+ theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
+ }
+ else if (sector >= sec_w && sector < sec_w + sections.v[2]) { // w
+ theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
+ }
+ else if (sector >= sec_w + sections.v[2]) { // e
+ theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+ }
}
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
const float freq_scale, const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
- const mrope_sections sections, queue_ptr stream) {
+ const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
- corr_dims, theta_scale, freq_factors, sections, item_ct1);
+ corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
} else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
- corr_dims, theta_scale, freq_factors, sections, item_ct1);
+ corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
}
}
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
if (dst->src[0]->type == GGML_TYPE_F16) {
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
- freq_factors, sections, main_stream);
+ freq_factors, sections, is_imrope, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F32) {
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
- main_stream);
+ is_imrope, main_stream);
} else {
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
uint32_t s1;
uint32_t s2;
int32_t sections[4];
+ uint32_t is_imrope;
uint32_t is_back;
uint32_t set_rows_stride;
};
memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
}
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
+
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
- { sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride,
+ { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
}, dryrun);
}
uint s1;
uint s2;
int sections[4];
+ uint is_imrope;
uint is_back;
uint set_rows_stride;
} p;
const uint sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
- if (sector < p.sections[0]) {
- theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
- }
- else if (sector >= p.sections[0] && sector < sec_w) {
- theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
- }
- else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
- theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
- }
- else if (sector >= sec_w + p.sections[2]) {
- theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+ if (p.is_imrope != 0) { // interleaved M-RoPE: the data_pos component is picked by sector % 3
+ if (sector % 3 == 1 && sector < 3 * p.sections[1]) { // h
+ theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+ } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { // w
+ theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+ } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { // t
+ theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
+ } else { // e: every sector not claimed by the t/h/w tests above
+ theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+ }
+ } else { // non-interleaved M-RoPE: contiguous sector ranges
+ if (sector < p.sections[0]) { // t
+ theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
+ }
+ else if (sector >= p.sections[0] && sector < sec_w) { // h
+ theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+ }
+ else if (sector >= sec_w && sector < sec_w + p.sections[2]) { // w
+ theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+ }
+ else if (sector >= sec_w + p.sections[2]) { // e
+ theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+ }
}
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
let is_neox = bool(params.mode & 2);
let is_mrope = bool(params.mode & 8);
+ let is_imrope = params.mode == 40;
let is_vision = params.mode == 24;
var i = gid.x * 2; // start index for this thread
let sec_w = params.sections1 + params.sections0;
let sec_e = params.sections2 + sec_w;
let sector = (i0 / 2) % sect_dims;
- if (sector >= params.sections0 && sector < sec_w) {
- theta_base_mult = 1;
- if (is_vision) {
- theta_scale_pwr = sector - params.sections0;
- }
- } else if (sector >= sec_w && sector < sec_e) {
- theta_base_mult = 2;
- if (is_vision) {
- theta_scale_pwr = sector - sec_w;
- }
- } else if (sector >= sec_e) {
- if (is_vision) {
- theta_scale_pwr = sector - sec_e;
- theta_scale_pwr = (i0 / 2) % sec_e;
- }
- theta_base_mult = 3;
- } else if (is_vision) {
- theta_scale_pwr = sector;
+ if (is_imrope) { // interleaved M-RoPE: theta_base_mult (the src1 plane index, in units of ne2) is picked by sector % 3
+ if (sector % 3 == 1 && sector < 3 * params.sections1) { // h
+ theta_base_mult = 1;
+ } else if (sector % 3 == 2 && sector < 3 * params.sections2) { // w
+ theta_base_mult = 2;
+ } else if (sector % 3 == 0 && sector < 3 * params.sections0) { // t
+ theta_base_mult = 0;
+ } else { // e: every sector not claimed by the t/h/w tests above
+ theta_base_mult = 3;
+ }
+ } else { // non-interleaved M-RoPE / vision: contiguous sector ranges
+ if (sector >= params.sections0 && sector < sec_w) { // h
+ theta_base_mult = 1;
+ if (is_vision) {
+ theta_scale_pwr = sector - params.sections0;
+ }
+ } else if (sector >= sec_w && sector < sec_e) { // w
+ theta_base_mult = 2;
+ if (is_vision) {
+ theta_scale_pwr = sector - sec_w;
+ }
+ } else if (sector >= sec_e) { // e
+ if (is_vision) {
+ theta_scale_pwr = sector - sec_e;
+ theta_scale_pwr = (i0 / 2) % sec_e; // NOTE(review): overwrites the value assigned on the previous line - confirm which one is intended
+ }
+ theta_base_mult = 3;
+ } else if (is_vision) { // sector < sections0 (t); theta_base_mult is not assigned in this branch
+ theta_scale_pwr = sector;
+ }
}
}
let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
+ test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)
+ test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 7B)
+ test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
+ test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
+ test_cases.emplace_back(new test_rope(type, {128, 16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl)
}
test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
// single inplace test per type/mode/ff
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION}) {
+ for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_IMROPE, GGML_ROPE_TYPE_VISION}) {
for (bool ff : {false, true}) {
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));
}