vk_pipeline pipeline_topk_f32[num_topk_pipelines];
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_cumsum_f32;
+ vk_pipeline pipeline_cumsum_small_f32;
+ vk_pipeline pipeline_cumsum_multipass1_f32;
+ vk_pipeline pipeline_cumsum_multipass2_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
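+ // cumsum spec constants: { workgroup size, subgroup size, elements per thread }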
+ const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
return nullptr;
case GGML_OP_CUMSUM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_cumsum_f32;
+ if (src0->ne[0] <= 512) {
+ return ctx->device->pipeline_cumsum_small_f32;
+ } else {
+ return ctx->device->pipeline_cumsum_f32;
+ }
}
return nullptr;
case GGML_OP_SOLVE_TRI:
}
static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
- vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
+ vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+ // Use the single-pass shader when the rows are small or there are enough rows to fill the GPU.
+ // For fewer, larger rows, use the multipass shaders to spread each row across shader cores.
+ if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
+ return;
+ }
+
+ // The first pass computes prefix sums within each block and stores each block's total
+ // in the temp buffer. The second pass sums the totals of all preceding blocks and
+ // adds that sum to the first pass's result.
+ vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
+ vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
+ GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);
+
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
+
+ std::array<uint32_t, 3> elements;
+
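+ // Dispatch dimensions: columns in x (split into 256-wide blocks per the pipeline wg_denoms), rows in y.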
+ elements[0] = dst->ne[0];
+ elements[1] = (uint32_t)ggml_nrows(dst);
+ elements[2] = 1;
+
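+ // Scratch for the per-block totals, reusing the split_k preallocation. One float per element is a
+ // conservative upper bound; pass 1 only writes one float per 256-column block per row.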
+ size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
+
+ if (ctx->prealloc_size_split_k < temp_size) {
+ ctx->prealloc_size_split_k = temp_size;
+ ggml_vk_preallocate_buffers(ctx, subctx);
+ }
+
+ vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+ vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+ vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
+
+ if (ctx->prealloc_split_k_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
+ ggml_vk_sync_buffers(ctx, subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
+
+ ctx->prealloc_split_k_need_sync = true;
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+layout (constant_id = 2) const uint ELEM_PER_THREAD = 4;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
last_sum = 0;
}
- uint col = tid;
- uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
+ uint col = tid * ELEM_PER_THREAD;
+ uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD);
for (int i = 0; i < num_iter; ++i) {
- FLOAT_TYPE v = 0;
- if (col < p.n_cols) {
- v = FLOAT_TYPE(data_a[src_idx + col]);
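+ // Each thread handles ELEM_PER_THREAD consecutive columns and builds its local inclusive prefix in v[].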
+ FLOAT_TYPE v[ELEM_PER_THREAD];
+ FLOAT_TYPE thread_sum = 0;
+ [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+ if (col + j < p.n_cols) {
+ thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]);
+ }
+ v[j] = thread_sum;
}
- v = subgroupInclusiveAdd(v);
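+ // Exclusive scan of the per-thread totals gives each thread the sum from all lower lanes in the subgroup.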
+ thread_sum = subgroupExclusiveAdd(thread_sum);
+ [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+ v[j] += thread_sum;
+ }
// Store the largest partial sum for each subgroup, then add the partials for all
// lower subgroups and the final partial sum from the previous iteration.
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
- partial[subgroup_id] = v;
+ partial[subgroup_id] = v[ELEM_PER_THREAD - 1];
}
barrier();
- for (int j = 0; j < subgroup_id; ++j) {
- v += partial[j];
+ for (int s = 0; s < subgroup_id; ++s) {
+ [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+ v[j] += partial[s];
+ }
+ }
+ [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+ v[j] += last_sum;
}
- v += last_sum;
barrier();
if (tid == BLOCK_SIZE - 1) {
- last_sum = v;
+ last_sum = v[ELEM_PER_THREAD - 1];
}
- if (col < p.n_cols) {
- data_d[dst_idx + col] = D_TYPE(v);
+ [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+ if (col + j < p.n_cols) {
+ data_d[dst_idx + col + j] = D_TYPE(v[j]);
+ }
}
- col += BLOCK_SIZE;
+ col += BLOCK_SIZE * ELEM_PER_THREAD;
}
}
--- /dev/null
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
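+// Binding 2 receives one block total per workgroup per row, consumed by the second pass.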
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 2) writeonly buffer T {D_TYPE data_t[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
+
+void main() {
+ const uint row = gl_WorkGroupID.y;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint col = gl_GlobalInvocationID.x;
+
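+ // Decompose the flattened row index into (i01, i02, i03) tensor coordinates.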
+ const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
+ const uint i03_offset = i03 * p.ne01*p.ne02;
+ const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
+ const uint i01 = row - i03_offset - i02*p.ne01;
+
+ const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
+ const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
+
+ uint subgroup_id = tid / SUBGROUP_SIZE;
+
+ FLOAT_TYPE v = 0;
+ if (col < p.n_cols) {
+ v = FLOAT_TYPE(data_a[src_idx + col]);
+ }
+ v = subgroupInclusiveAdd(v);
+
+ // Store the largest partial sum for each subgroup, then add the partials for all
+ // lower subgroups.
+ if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
+ partial[subgroup_id] = v;
+ }
+ barrier();
+ for (int j = 0; j < subgroup_id; ++j) {
+ v += partial[j];
+ }
+ barrier();
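+ // The last thread now holds the block total; store it for the second pass.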
+ if (tid == BLOCK_SIZE - 1) {
+ data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row] = v;
+ }
+ if (col < p.n_cols) {
+ data_d[dst_idx + col] = D_TYPE(v);
+ }
+}
--- /dev/null
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
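+// Binding 1 holds the first-pass prefix sums and receives the final result; binding 2 holds the per-block totals from pass 1.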
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) buffer D {D_TYPE data_d[];};
+layout (binding = 2) readonly buffer T {D_TYPE data_t[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE temp[BLOCK_SIZE / SUBGROUP_SIZE];
+
+void main() {
+ const uint row = gl_WorkGroupID.y;
+ const uint tid = gl_LocalInvocationID.x;
+
+ const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
+ const uint i03_offset = i03 * p.ne01*p.ne02;
+ const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
+ const uint i01 = row - i03_offset - i02*p.ne01;
+
+ const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
+ const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
+
+ const uint col = gl_GlobalInvocationID.x;
+
+ float v = 0;
+ // prefetch the first-pass prefix sum that the block offset will be added to
+ if (col < p.n_cols) {
+ v = data_d[dst_idx + col];
+ }
+
+ // compute the sum of all previous blocks
+ uint c = tid;
+ float sum = 0;
+ while (c < gl_WorkGroupID.x) {
+ sum += data_t[c + gl_NumWorkGroups.x * row];
+ c += BLOCK_SIZE;
+ }
+
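+ // Reduce the per-thread partial sums across the workgroup: subgroup add, then sum the per-subgroup results through shared memory.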
+ sum = subgroupAdd(sum);
+ if (gl_SubgroupInvocationID == 0) {
+ temp[gl_SubgroupID] = sum;
+ }
+ barrier();
+ sum = 0;
+ [[unroll]] for (uint s = 0; s < BLOCK_SIZE / SUBGROUP_SIZE; ++s) {
+ sum += temp[s];
+ }
+
+ // Add the sum to what the first pass computed
+ if (col < p.n_cols) {
+ data_d[dst_idx + col] = v + sum;
+ }
+}
+