uint32_t orig_ncols;
uint32_t ncols_input;
uint32_t ncols_output;
+ uint32_t k; // number of top elements to select per row
uint32_t nrows;
uint32_t first_pass;
uint32_t last_pass;
timings[name.str()].push_back(time);
return;
}
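+ // Record TOP_K timings under a per-shape key so runs with different K values and source sizes are tracked separately.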
+ if (node->op == GGML_OP_TOP_K) {
+ std::stringstream name;
+ name << ggml_op_name(node->op) <<
+ " K=" << node->ne[0] <<
+ " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
+ timings[name.str()].push_back(time);
+ return;
+ }
timings[ggml_op_name(node->op)].push_back(time);
}
private:
uint32_t nrows = ggml_nrows(src0);
uint32_t k = dst->ne[0];
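+ // orig_ncols/ncols_input/ncols_output all start at ncols; the per-pass copy below updates ncols_output as the candidate set shrinks.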
- vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
+ vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };
- // Reserve space for ivec2 per element, double buffered
- const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
- const size_t x_sz = dbl_buf_size * 2;
- uint32_t dbl_buf_index = 0;
-
- if (ctx->prealloc_size_x < x_sz) {
- ctx->prealloc_size_x = x_sz;
- ggml_vk_preallocate_buffers(ctx, subctx);
- }
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
// largest elements. Repeat until we have the top K elements.
// Need to do at least one iteration to write out the results.
bool done_one_iter = false;
+ uint32_t dbl_buf_index = 0;
+ size_t dbl_buf_size = 0; // computed on the first loop iteration below
while (num_elements > k || !done_one_iter) {
- done_one_iter = true;
// Prefer going as small as num_topk_pipelines - 3 for perf reasons,
// but if K is larger we need a workgroup big enough to keep at least K elements per pass.
// Number of elements remaining after this pass
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
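+ // Each full workgroup keeps its local top K; the trailing partial workgroup keeps at most K of the remainder.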
+ pc2.ncols_output = num_dst_elements;
+
+ if (!done_one_iter) {
+ // Reserve space for ivec2 per element, double buffered
+ // K per workgroup per row
+ dbl_buf_size = size_t{num_dst_elements} * nrows * 2 * sizeof(int);
+ dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
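+ // Round up to the storage buffer offset alignment so the second half of the double buffer can be bound at an offset.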
+ const size_t x_sz = dbl_buf_size * 2;
+
+ if (ctx->prealloc_size_x < x_sz) {
+ ctx->prealloc_size_x = x_sz;
+ ggml_vk_preallocate_buffers(ctx, subctx);
+ }
+ }
+
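+ // The two halves of prealloc_x are ping-ponged between passes via dbl_buf_index.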
vk_subbuffer src_buf;
vk_subbuffer dst_buf;
if (num_elements > k) {
ggml_vk_sync_buffers(ctx, subctx);
}
+ done_one_iter = true;
}
ctx->prealloc_x_need_sync = true;
}
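// Illustration only (not part of the change), with a hypothetical workgroup
// width `wg` standing in for pipeline->wg_denoms[0]: each pass keeps K
// candidates per full workgroup, so the candidate count shrinks geometrically
// until at most K remain. The numbers below are made up for the example.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t num_elements = 100000, k = 40, wg = 1024;
    while (num_elements > k) {
        // Full workgroups keep k each; the partial one keeps min(k, remainder).
        const uint32_t num_dst = (num_elements / wg) * k + std::min(k, num_elements % wg);
        std::printf("%u -> %u\n", num_elements, num_dst);
        num_elements = num_dst; // the next pass selects from this pass's output
    }
}
// With these numbers the pass sizes go 100000 -> 3920 -> 160 -> 40.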
uint orig_ncols;
uint ncols_input;
uint ncols_output;
+ uint k;
uint nrows;
uint first_pass;
uint last_pass;
const uint row_offset = row * p.ncols_input;
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
} else {
- const uint row_offset = row * p.orig_ncols;
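+ // Later passes read the (index, value-bits) pairs written by the previous pass.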
+ const uint row_offset = row * p.ncols_input;
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
}
} else {
}
barrier();
- if (p.ncols_output == 1) {
+ if (p.k == 1) {
// Fast path for single output - just do a max reduction
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
if (col < s) {
}
}
- if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
+ if (col < p.k) {
if (p.last_pass != 0) {
- const uint row_offset = row * p.ncols_output;
- data_d[row_offset + col] = dst_row[col].x;
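+ // Final pass: only the index of each of the top K survives to the destination.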
+ if (gl_GlobalInvocationID.x < p.ncols_input) {
+ const uint row_offset = row * p.k;
+ data_d[row_offset + col] = dst_row[col].x;
+ }
} else {
- const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
- data_t[row_offset + col] = dst_row[col];
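+ // Intermediate pass: store this workgroup's top K pairs, clamped to the overall output width of this pass.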
+ if (gl_WorkGroupID.x * p.k + col < p.ncols_output) {
+ const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
+ data_t[row_offset + col] = dst_row[col];
+ }
}
}
}
uint orig_ncols;
uint ncols_input;
uint ncols_output;
+ uint k;
uint nrows;
uint first_pass;
uint last_pass;
const uint row_offset = row * p.ncols_input;
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
} else {
- const uint row_offset = row * p.orig_ncols;
+ const uint row_offset = row * p.ncols_input;
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
}
} else {
}
barrier();
- if (p.ncols_output == 1) {
+ if (p.k == 1) {
// Fast path for single output - just do a max reduction
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
if (tid < s) {
uint range_max = 0xFF800000; // IEEE-754 bit pattern of float -inf
// How many are above the current range, and how many we need to find.
uint total = 0;
- uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
+ uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
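+ // Narrow the candidate value range until exactly K elements remain above it (total == p.k).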
while (mask != 0) {
barrier();
range_max = range_min + ((min_idx + 1) << shift);
range_min = range_min + (min_idx << shift);
- if (total == p.ncols_output) {
+ if (total == p.k) {
break;
}
total -= counts[min_idx];
barrier();
}
- if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
+ if (tid < p.k) {
if (p.last_pass != 0) {
- const uint row_offset = row * p.ncols_output;
- data_d[row_offset + tid] = dst_row[tid].x;
+ if (gl_GlobalInvocationID.x < p.ncols_input) {
+ const uint row_offset = row * p.k;
+ data_d[row_offset + tid] = dst_row[tid].x;
+ }
} else {
- const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
- data_t[row_offset + tid] = dst_row[tid];
+ if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) {
+ const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
+ data_t[row_offset + tid] = dst_row[tid];
+ }
}
}
}