const bool is_neox = mode & 2;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+ const int is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor));
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast));
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow));
+ // both mrope and vision kernels have sections
if (is_mrope || is_vision) {
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, §ions));
}
+ // only mrope has is_imrope
+ if (is_mrope && !is_vision) {
+ CL_CHECK(clSetKernelArg(kernel, 34, sizeof(int), &is_imrope));
+ }
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
float attn_factor,
float beta_fast,
float beta_slow,
- int4 sections
+ int4 sections,
+ int is_imrope
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
- if (sector < sections.s0) {
- theta_base = pos[i2];
- }
- else if (sector >= sections.s0 && sector < sec_w) {
- theta_base = pos[i2 + ne2 * 1];
- }
- else if (sector >= sec_w && sector < sec_w + sections.s2) {
- theta_base = pos[i2 + ne2 * 2];
- }
- else if (sector >= sec_w + sections.s2) {
- theta_base = pos[i2 + ne2 * 3];
+ if (is_imrope) {
+ if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
+ theta_base = (float) pos[i2 + ne02 * 1];
+ } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
+ theta_base = (float) pos[i2 + ne02 * 2];
+ } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
+ theta_base = (float) pos[i2 + ne02 * 0];
+ } else { // e
+ theta_base = (float) pos[i2 + ne02 * 3];
+ }
+ } else {
+ if (sector < sections.s0) {
+ theta_base = pos[i2];
+ }
+ else if (sector >= sections.s0 && sector < sec_w) {
+ theta_base = pos[i2 + ne2 * 1];
+ }
+ else if (sector >= sec_w && sector < sec_w + sections.s2) {
+ theta_base = pos[i2 + ne2 * 2];
+ }
+ else if (sector >= sec_w + sections.s2) {
+ theta_base = pos[i2 + ne2 * 3];
+ }
}
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
float attn_factor,
float beta_fast,
float beta_slow,
- int4 sections
+ int4 sections,
+ int is_imrope
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
- if (sector < sections.s0) {
- theta_base = pos[i2];
- }
- else if (sector >= sections.s0 && sector < sec_w) {
- theta_base = pos[i2 + ne2 * 1];
- }
- else if (sector >= sec_w && sector < sec_w + sections.s2) {
- theta_base = pos[i2 + ne2 * 2];
- }
- else if (sector >= sec_w + sections.s2) {
- theta_base = pos[i2 + ne2 * 3];
+ if (is_imrope) {
+ if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
+ theta_base = (float) pos[i2 + ne02 * 1];
+ } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
+ theta_base = (float) pos[i2 + ne02 * 2];
+ } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
+ theta_base = (float) pos[i2 + ne02 * 0];
+ } else { // e
+ theta_base = (float) pos[i2 + ne02 * 3];
+ }
+ } else {
+ if (sector < sections.s0) {
+ theta_base = pos[i2];
+ }
+ else if (sector >= sections.s0 && sector < sec_w) {
+ theta_base = pos[i2 + ne2 * 1];
+ }
+ else if (sector >= sec_w && sector < sec_w + sections.s2) {
+ theta_base = pos[i2 + ne2 * 2];
+ }
+ else if (sector >= sec_w + sections.s2) {
+ theta_base = pos[i2 + ne2 * 3];
+ }
}
const float theta = theta_base * pow(freq_base, inv_ndims*i0);