- Launch an appropriate number of invocations (next larger power of two).
32 invocations is common and the barrier is much cheaper there.
- Specialize for "needs bounds checking" vs not.
- Make the code less branchy and [[unroll]] the loops. In the final code,
I see no branches inside the main loop (only predicated stores) when
needs_bounds_check is false.
- Always sort ascending, then apply the ascending vs descending option when
doing the final stores to memory.
- Copy the values into shared memory, makes them slightly cheaper to access.
CONV_SHAPE_COUNT,
};
+static constexpr uint32_t num_argsort_pipelines = 11;
+static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
+
struct vk_device_struct {
std::recursive_mutex mutex;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
- vk_pipeline pipeline_argsort_f32;
+ vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
struct vk_op_argsort_push_constants {
uint32_t ncols;
- uint32_t ncols_pad;
int32_t order;
};
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
}
- ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
+ for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
+ ggml_vk_create_pipeline(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
+ }
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
}
case GGML_OP_ARGSORT:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
- return ctx->device->pipeline_argsort_f32;
+ uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+ return ctx->device->pipeline_argsort_f32[idx];
}
return nullptr;
case GGML_OP_SUM:
uint32_t ncols = src0->ne[0];
- uint32_t ncols_pad = 1;
- while (ncols_pad < ncols) {
- ncols_pad *= 2;
- }
-
- GGML_ASSERT(ncols_pad <= 1024);
-
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
- ncols_pad,
op_params[0],
}, dryrun);
}
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return op->src[0]->type == GGML_TYPE_F32;
+ case GGML_OP_ARGSORT:
+ return op->ne[0] <= max_argsort_cols;
case GGML_OP_UPSCALE:
case GGML_OP_ACC:
case GGML_OP_CONCAT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
- case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGMAX:
#version 450
+#extension GL_EXT_control_flow_attributes : enable
#include "types.comp"
-#define BLOCK_SIZE 1024
+layout(constant_id = 0) const int BLOCK_SIZE = 1024;
+layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
#define ASC 0
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
- uint ncols_pad;
uint order;
} p;
shared int dst_row[BLOCK_SIZE];
+shared A_TYPE a_sh[BLOCK_SIZE];
void swap(uint idx0, uint idx1) {
int tmp = dst_row[idx0];
dst_row[idx1] = tmp;
}
-void main() {
+void argsort(bool needs_bounds_check) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
// initialize indices
- if (col < p.ncols_pad) {
- dst_row[col] = col;
- }
+ dst_row[col] = col;
+ a_sh[col] = data_a[row_offset + col];
barrier();
- for (uint k = 2; k <= p.ncols_pad; k *= 2) {
- for (uint j = k / 2; j > 0; j /= 2) {
- const uint ixj = col ^ j;
- if (col < p.ncols_pad && ixj > col) {
- if ((col & k) == 0) {
- if (dst_row[col] >= p.ncols ||
- (dst_row[ixj] < p.ncols && (p.order == ASC ?
- data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
- data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
- ) {
- swap(col, ixj);
- }
- } else {
- if (dst_row[ixj] >= p.ncols ||
- (dst_row[col] < p.ncols && (p.order == ASC ?
- data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
- data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
- ) {
- swap(col, ixj);
- }
- }
+ uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
+ [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
+ uint num_inner_loop_iters = outer_idx + 1;
+ [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
+ const int ixj = int(col ^ j);
+
+ int idx_0 = (col & k) == 0 ? col : ixj;
+ int idx_1 = (col & k) == 0 ? ixj : col;
+
+ int sh_idx_0 = dst_row[idx_0];
+ int sh_idx_1 = dst_row[idx_1];
+ bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
+ bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
+
+ if ((idx_0_oob ||
+ (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
+ swap(idx_0, idx_1);
}
+
barrier();
}
}
if (col < p.ncols) {
- data_d[row_offset + col] = dst_row[col];
+ if (p.order == ASC) {
+ data_d[row_offset + col] = dst_row[col];
+ } else {
+ data_d[row_offset + p.ncols - col - 1] = dst_row[col];
+ }
+ }
+}
+
+void main() {
+ if (p.ncols == BLOCK_SIZE) {
+ argsort(false);
+ } else {
+ argsort(true);
}
}
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
+ test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
}
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {