struct vk_op_rope_push_constants {
uint32_t rope_mode;
uint32_t ncols;
+ uint32_t nrows;
uint32_t n_dims;
float freq_scale;
uint32_t p_delta_rows;
elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
} break;
case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_ROPE:
- case GGML_OP_ROPE_BACK:
elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
break;
+ case GGML_OP_ROPE: // ROPE and ROPE_BACK share the dispatch-size computation below
+ case GGML_OP_ROPE_BACK:
+ {
+ uint32_t nrows = (uint32_t)ggml_nrows(src0); // total row count of src0, cast to uint32_t
+ uint32_t z = 1; // third element of the dispatch size; stays 1 unless the rows are split below
+ if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) { // row count exceeds the device-reported limit for dimension 0
+ z = CEIL_DIV(nrows, 32768); // number of 32768-row chunks (CEIL_DIV presumably rounds up — confirm against its definition)
+ nrows = 32768; // cap dimension 0 at one chunk. NOTE(review): the split is triggered by the device limit but uses a fixed 32768 chunk, so this assumes the limit is >= 32768 and that the consumer rebuilds the row index with the same constant — confirm
+ }
+ elements = { nrows, (uint32_t)ne00, z }; // { rows (possibly capped), ne00, number of row chunks }
+
+ } break;
case GGML_OP_GET_ROWS:
elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
vk_op_rope_push_constants rope {
- (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
+ (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
has_ff, (uint32_t)src0->ne[2], nb01, nb02,
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,