// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
// But if K is larger, then we need a larger workgroup
- uint32_t max_pipeline = num_topk_pipelines - 3;
+ uint32_t max_pipeline = num_topk_pipelines - 1;
+ uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);
+ max_pipeline = std::min(preferred_pipeline, max_pipeline);
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
// require full subgroup
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);