vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_sum_rows_f32;
+ vk_pipeline pipeline_cumsum_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
+
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
#define IM2COL(bda) \
return ctx->device->pipeline_sum_rows_f32;
}
return nullptr;
+ case GGML_OP_CUMSUM:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_cumsum_f32;
+ }
+ return nullptr;
case GGML_OP_ARGMAX:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
return ctx->device->pipeline_argmax_f32;
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_CUMSUM:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
{
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
}
+static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
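+ // Cumulative sum along dim 0: one scan per row, reusing the sum_rows push-constant layout with the full row length (ne[0]) as the column count.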
+ vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
+}
+
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
}
case GGML_OP_SUM_ROWS:
ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
+ break;
+ case GGML_OP_CUMSUM:
+ ggml_vk_cumsum(ctx, compute_ctx, src0, node);
+
break;
case GGML_OP_MEAN:
ggml_vk_mean(ctx, compute_ctx, src0, node);
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
+ case GGML_OP_CUMSUM:
+ {
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+ auto device = ggml_vk_get_device(ctx->device);
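+ // The cumsum shader relies on subgroup arithmetic and a full, fixed subgroup size, so only report support when the device guarantees both.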
+ if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
+ return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
+ }
+ return false;
+ }
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SUM_ROWS) {
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
+ } else if (tensor->op == GGML_OP_CUMSUM) {
+ tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_MEAN) {
tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_ARGMAX) {
--- /dev/null
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
+shared FLOAT_TYPE last_sum;
+
+void main() {
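+ // One workgroup scans a single row; the row index is flattened across the 3D workgroup ID (262144 = 512 * 512 undoes the splitting of large dispatches).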
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
+
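+ // Decompose the flat row index into tensor coordinates i01/i02/i03 using the precomputed fastdiv constants from the push constants.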
+ const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
+ const uint i03_offset = i03 * p.ne01*p.ne02;
+ const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
+ const uint i01 = row - i03_offset - i02*p.ne01;
+
+ const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
+ const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
+
+ uint subgroup_id = tid / SUBGROUP_SIZE;
+
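+ // last_sum carries the running total across blocks of columns; initialize it once before the scan.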
+ if (tid == 0) {
+ last_sum = 0;
+ }
+
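+ // Walk the row in BLOCK_SIZE-wide blocks of columns, scanning each block and carrying the total into the next one.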
+ uint col = tid;
+ uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
+ for (uint i = 0; i < num_iter; ++i) {
+ FLOAT_TYPE v = 0;
+ if (col < p.n_cols) {
+ v = FLOAT_TYPE(data_a[src_idx + col]);
+ }
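+ // Inclusive prefix sum within the subgroup: each invocation receives the sum of its own and all lower lanes.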
+ v = subgroupInclusiveAdd(v);
+
+ // Each subgroup's last invocation now holds that subgroup's total. Store it, then
+ // add the totals of all lower subgroups and the running sum carried over from the
+ // previous block of columns.
+ if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
+ partial[subgroup_id] = v;
+ }
+ barrier();
+ for (uint j = 0; j < subgroup_id; ++j) {
+ v += partial[j];
+ }
+ v += last_sum;
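+ // Make sure every invocation has read partial[] and last_sum before they are overwritten for the next block.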
+ barrier();
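+ // The last invocation in the workgroup now holds the running total including this block; carry it forward.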
+ if (tid == BLOCK_SIZE - 1) {
+ last_sum = v;
+ }
+ if (col < p.n_cols) {
+ data_d[dst_idx + col] = D_TYPE(v);
+ }
+ col += BLOCK_SIZE;
+ }
+}
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
+ string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
for (std::string dim_str : {"", "_3d"}) {
for (bool bda : {false, true}) {