// Push constants for the Vulkan argsort shader. The '+' line is the patch:
// nrows is added so the shader can bound a per-workgroup row loop when the
// dispatch's Y dimension is clamped to the device limit (see the
// maxComputeWorkGroupCount[1] clamp in the dispatch code below).
struct vk_op_argsort_push_constants {
uint32_t ncols;
+ uint32_t nrows;
int32_t order;
};
break;
// Argsort launches one workgroup per row in grid dimension Y. The added
// clamp keeps elements[1] within the device's maxComputeWorkGroupCount[1];
// the shader compensates by looping over the remaining rows (it compares
// its row index against p.nrows).
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
case GGML_OP_IM2COL:
{
// Argsort dispatch: forwards ncols, the (new) total row count, and the
// sort order (op_params[0]) as push constants to the shader.
int32_t * op_params = (int32_t *)dst->op_params;
uint32_t ncols = src0->ne[0];
+ uint32_t nrows = ggml_nrows(src0);
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
+ nrows,
op_params[0],
}, dryrun);
}
// GLSL mirror of the host-side vk_op_argsort_push_constants struct; field
// order must match the C++ layout. The added nrows bounds the row loop in
// main() when fewer workgroups than rows were dispatched.
layout (push_constant) uniform parameter {
uint ncols;
+ uint nrows;
uint order;
} p;
dst_row[idx1] = tmp;
}
// Patch: the row index becomes an explicit parameter instead of being
// derived from gl_WorkGroupID.y inside the function, so main() can invoke
// argsort() repeatedly for different rows from the same workgroup.
// (Function body is truncated in this excerpt.)
-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
- const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
void main() {
// Fast path: the row occupies exactly one workgroup's worth of columns,
// so the bitonic sort can skip its bounds checks.
if (p.ncols == BLOCK_SIZE) {
- argsort(false);
// Patch: start at this workgroup's Y index and advance by the grid's
// total Y extent, so all p.nrows rows are covered even when the
// dispatch was clamped to maxComputeWorkGroupCount[1].
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(false, row);
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
} else {
- argsort(true);
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(true, row);
// NOTE(review): if the local workgroup size in Y is 1 (not visible in
// this excerpt), this stride equals plain gl_NumWorkGroups.y — confirm
// the layout(local_size_y) declaration matches the intended stride.
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
}
}