struct vk_op_rope_push_constants {
uint32_t rope_mode;
- uint32_t ncols;
uint32_t nrows;
uint32_t n_dims;
float freq_scale;
- uint32_t p_delta_rows;
float freq_base;
float ext_factor;
float attn_factor;
float corr_dims[2];
float theta_scale;
uint32_t has_ff;
- uint32_t ne02;
- uint32_t s1;
- uint32_t s2;
int32_t sections[4];
uint32_t is_imrope;
uint32_t is_back;
uint32_t set_rows_stride;
+ uint32_t ne00;
+ uint32_t ne01;
+ uint32_t ne02;
+ uint32_t nb01;
+ uint32_t nb02;
+ uint32_t nb03;
+ uint32_t nb11;
+ uint32_t nb12;
+ uint32_t nb13;
};
+static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
// For fused rms_norm+mul+rope(+view+set_rows)
struct vk_op_rms_norm_mul_rope_push_constants {
uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
+ uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
+
+ uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
+ uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
+ uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
vk_op_rope_push_constants rope {
- (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
- freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
- has_ff, (uint32_t)src0->ne[2], nb01, nb02,
+ (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
+ freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
+
+ (uint32_t)src0->ne[0],
+ (uint32_t)src0->ne[1],
+ (uint32_t)src0->ne[2],
+ nb01, nb02, nb03,
+ nb11, nb12, nb13,
};
return rope;
case GGML_OP_REPEAT_BACK:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ROPE:
+ return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_ROPE_BACK:
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
return 1.0f - min(1.0f, max(0.0f, y));
}
-uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
+uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) {
#if RMS_NORM_ROPE_FUSION
// Per-row offset in shared memory
const uint ix = i0;
#else
- const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
+ const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
#endif
return ix;
}
sin_theta = sin(theta) * mscale;
}
-void rope_norm(const uint i0, const uint i1, rope_params p) {
- uint ne0 = p.ncols;
- uint ne1 = p.p_delta_rows;
-
- if (i0 >= ne0) {
+void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+ if (i0 >= p.ne00) {
return;
}
- // i1 is actually i2*nb2+i1, but the rows are contiguous
- const uint i01 = i1 % ne1;
- const uint i02 = i1 / ne1;
-
- uint idst = i1*ne0 + i0;
- const uint ix = rope_a_coord(i0, i01, i02, p);
+ uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+ const uint ix = rope_a_coord(i0, i1, i2, i3, p);
// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
if (p.set_rows_stride != 0) {
- idst = i01*ne0 + i0;
- idst += rope_data_i[i02].x * p.set_rows_stride;
+ idst = i1*p.nb11 + i0;
+ idst += rope_data_i[i2].x * p.set_rows_stride;
}
if (i0 >= p.n_dims) {
return;
}
- const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
+ const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
}
-void rope_neox(const uint i0, const uint i1, rope_params p) {
- uint ne0 = p.ncols;
- uint ne1 = p.p_delta_rows;
-
- if (i0 >= ne0) {
+void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+ if (i0 >= p.ne00) {
return;
}
- const uint i01 = i1 % ne1;
- const uint i02 = i1 / ne1;
-
- uint idst = i1*ne0 + i0/2;
- const uint ix = rope_a_coord(i0/2, i01, i02, p);
+ uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+ const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
if (p.set_rows_stride != 0) {
- idst = i01*ne0 + i0/2;
- idst += rope_data_i[i02].x * p.set_rows_stride;
+ idst = i1*p.nb11 + i0/2;
+ idst += rope_data_i[i2].x * p.set_rows_stride;
}
if (i0 >= p.n_dims) {
return;
}
- const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
+ const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
}
-void rope_multi(const uint i0, const uint i1, rope_params p) {
- uint ne0 = p.ncols;
- uint ne1 = p.p_delta_rows;
- uint ne2 = p.ne02;
-
- if (i0 >= ne0) {
+void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+ if (i0 >= p.ne00) {
return;
}
- const uint i01 = i1 % ne1;
- const uint i02 = i1 / ne1;
-
- uint idst = i1*ne0 + i0/2;
- const uint ix = rope_a_coord(i0/2, i01, i02, p);
+ uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+ const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
if (p.set_rows_stride != 0) {
- idst = i01*ne0 + i0/2;
- idst += rope_data_i[i02].x * p.set_rows_stride;
+ idst = i1*p.nb11 + i0/2;
+ idst += rope_data_i[i2].x * p.set_rows_stride;
}
if (i0 >= p.n_dims) {
float theta_base = 0.0;
if (p.is_imrope != 0) {
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
- theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+ theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
- theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+ theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
- theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
+ theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
} else {
- theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+ theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
}
} else {
if (sector < p.sections[0]) {
- theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
+ theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
- theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+ theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
- theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+ theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
- theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+ theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
}
}
rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
}
-void rope_vision(const uint i0, const uint i1, rope_params p) {
- uint ne0 = p.ncols;
- uint ne1 = p.p_delta_rows;
- uint ne2 = p.ne02;
-
- if (i0 >= ne0) {
+void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+ if (i0 >= p.ne00) {
return;
}
- const uint i01 = i1 % ne1;
- const uint i02 = i1 / ne1;
-
- const uint idst = i1*ne0 + i0/2;
- const uint ix = rope_a_coord(i0/2, i01, i02, p);
+ const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+ const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
const int sect_dims = p.sections[0] + p.sections[1];
const int sec_w = p.sections[1] + p.sections[0];
float theta_base = 0.0;
if (sector < p.sections[0]) {
const uint p0 = sector;
- theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
+ theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0);
}
else if (sector >= p.sections[0] && sector < sec_w) {
const uint p0 = sector - p.sections[0];
- theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
+ theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0);
}
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;