uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
sizeof(int) * device->subgroup_size +
2 * sizeof(int) +
- (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
+ 2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
shared int sh_min_idx;
shared uint sh_total;
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE]; // one slot per subgroup: ballot count of its `top` values, used to derive each subgroup's base output offset
+shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE]; // one slot per subgroup: ballot count of its values equal to range_min (tie handling)
// Map float values to uint such that comparisons still work.
// Positive values set the high bit, negative values are inverted.
// We need to compact these values to the start of the dst_row array.
// Have each subgroup count how many items it'll store, so other
// subgroups can compute their base offset.
- bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
- uvec4 b = subgroupBallot(top);
- uint bit_count = subgroupBallotBitCount(b);
- if ((tid % SUBGROUP_SIZE) == 0) {
- offset_partials[tid / SUBGROUP_SIZE] = bit_count;
- }
- barrier();
+ // Values strictly greater than range_min must always be stored. Values equal
+ // to range_min may be tied, in which case only an arbitrary subset of them
+ // is stored (those whose output index is still below p.k).
+ // If total == p.k, take a fast path that doesn't need to handle ties.
+ if (total == p.k) {
+ bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
+ uvec4 b = subgroupBallot(top);
+ uint bit_count = subgroupBallotBitCount(b);
+ if ((tid % SUBGROUP_SIZE) == 0) {
+ offset_partials[tid / SUBGROUP_SIZE] = bit_count;
+ }
+ barrier();
- uint out_idx = 0;
- [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
- if (i < tid / SUBGROUP_SIZE) {
- out_idx += offset_partials[i];
+ uint out_idx = 0;
+ [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+ if (i < tid / SUBGROUP_SIZE) {
+ out_idx += offset_partials[i];
+ }
}
- }
- uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
- if (top) {
- // TODO: Copy directly to the output?
- dst_row[out_idx + bit_count_ex] = v;
+ uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
+ if (top) {
+ // TODO: Copy directly to the output?
+ dst_row[out_idx + bit_count_ex] = v;
+ }
+ } else {
+ bool top = f2ui(intBitsToFloat(v.y)) > range_min;
+ bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
+ uvec4 b_top = subgroupBallot(top);
+ uvec4 b_eq_min = subgroupBallot(eq_min);
+ uint bit_count_top = subgroupBallotBitCount(b_top);
+ uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
+ if ((tid % SUBGROUP_SIZE) == 0) {
+ offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
+ eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
+ }
+ barrier();
+
+ uint out_idx = 0;
+ uint eq_min_base = 0;
+ uint eq_min_idx = 0;
+ [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+ if (i < tid / SUBGROUP_SIZE) {
+ out_idx += offset_partials[i];
+ eq_min_idx += eq_min_partials[i];
+ }
+ eq_min_base += offset_partials[i];
+ }
+ // Values equal to range_min are stored after all values greater than range_min.
+ eq_min_idx += eq_min_base;
+
+ uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
+ uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
+ if (top) {
+ // TODO: Copy directly to the output?
+ dst_row[out_idx + bit_count_ex_top] = v;
+ }
+ if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
+ dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
+ }
}
barrier();