From: Jeff Bolz Date: Fri, 26 Dec 2025 15:53:46 +0000 (-0600) Subject: vulkan: handle rope with large number of rows (llama/18306) X-Git-Tag: upstream/1.8.3~85 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=67473fef57eaf072671892be62384bea373ece21;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp vulkan: handle rope with large number of rows (llama/18306) --- diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1459b260..e7ce518f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1192,6 +1192,7 @@ struct vk_op_diag_mask_push_constants { struct vk_op_rope_push_constants { uint32_t rope_mode; uint32_t ncols; + uint32_t nrows; uint32_t n_dims; float freq_scale; uint32_t p_delta_rows; @@ -9090,10 +9091,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; } break; case GGML_OP_DIAG_MASK_INF: - case GGML_OP_ROPE: - case GGML_OP_ROPE_BACK: elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; break; + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + { + uint32_t nrows = (uint32_t)ggml_nrows(src0); + uint32_t z = 1; + if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) { + z = CEIL_DIV(nrows, 32768); + nrows = 32768; + } + elements = { nrows, (uint32_t)ne00, z }; + + } break; case GGML_OP_GET_ROWS: elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); @@ -10021,7 +10032,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type); vk_op_rope_push_constants rope { - (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff, (uint32_t)src0->ne[2], nb01, nb02, { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 7c1fb1cd..f7587468 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -6,6 +6,9 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x; + const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (i1 >= pc.nrows) { + return; + } rope_multi(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 68f00c18..acb8ed78 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -6,6 +6,9 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x; + const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (i1 >= pc.nrows) { + return; + } rope_neox(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 28a939ec..0033cdb2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -6,6 +6,9 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x; + const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (i1 >= pc.nrows) { + return; + } rope_norm(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index 82f39cee..939cf3c5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -6,6 +6,7 @@ struct rope_params { uint rope_mode; uint ncols; + uint nrows; uint n_dims; float freq_scale; uint p_delta_rows; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index ea1e0fdb..d93800b5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -6,6 +6,9 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x; + const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (i1 >= pc.nrows) { + return; + } rope_vision(i0, i1, pc); }