}
};
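+// Key for the solve_tri_f32 pipeline map: one pipeline is specialized per (N, K) problem size.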
+struct vk_solve_tri_pipeline_state {
+ vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
+ : N(N), K(K) {}
+
+ uint32_t N, K;
+
+ bool operator<(const vk_solve_tri_pipeline_state &b) const {
+ return std::tie(N, K) <
+ std::tie(b.N, b.K);
+ }
+};
+
enum shader_reduction_mode {
SHADER_REDUCTION_MODE_SHMEM,
SHADER_REDUCTION_MODE_HYBRID,
vk_pipeline pipeline_cumsum_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
+ std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32;
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
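+    // Compile one solve_tri_f32 pipeline per (N, K) entry registered in the map; N and K are passed as specialization constants.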
+ for (auto &s : device->pipeline_solve_tri_f32) {
+ const vk_solve_tri_pipeline_state &state = s.first;
+ ggml_vk_create_pipeline(
+ device, s.second, "solve_tri_f32",
+ solve_tri_f32_len, solve_tri_f32_data, "main", 3,
+ sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
+ }
+
#define IM2COL(bda) \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
return ctx->device->pipeline_cumsum_f32;
}
return nullptr;
+ case GGML_OP_SOLVE_TRI:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+
+ vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]);
+
+ vk_pipeline pipeline = nullptr;
+
+ {
+ std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
+ auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
+ if (it != ctx->device->pipeline_solve_tri_f32.end()) {
+ pipeline = it->second;
+ } else {
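+                // First request for this (N, K): register a placeholder entry; the loop over pipeline_solve_tri_f32 during shader loading creates the actual pipeline.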
+ ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
+ }
+ }
+
+ return pipeline;
+ }
+ return nullptr;
case GGML_OP_ARGMAX:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
return ctx->device->pipeline_argmax_f32;
elements = { nr, 1, 1 };
}
} break;
+ case GGML_OP_SOLVE_TRI:
+ {
+ uint32_t nr = (uint32_t)(ne02 * ne03);
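+                // Split nr workgroups over the 3D dispatch: x and y are capped at 512, so each z-slice covers 512 * 512 = 262144 batches.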
+ if (nr > 262144) {
+ elements = { 512, 512, CEIL_DIV(nr, 262144) };
+ } else if (nr > 512) {
+ elements = { 512, CEIL_DIV(nr, 512), 1 };
+ } else {
+ elements = { nr, 1, 1 };
+ }
+ }
+ break;
case GGML_OP_RMS_NORM:
if (ctx->do_add_rms_partials) {
// Run one element per thread, 128 threads per workgroup
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
}
+static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
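+    // Reuse the generic binary push-constant layout: element count, then ne/nb of src0, src1 and dst in element units.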
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f, 0,
+ });
+}
+
static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int32_t s0 = dst->op_params[0];
const int32_t s1 = dst->op_params[1];
case GGML_OP_COUNT_EQUAL:
ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);
+ break;
+ case GGML_OP_SOLVE_TRI:
+ ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node);
+
break;
case GGML_OP_IM2COL:
ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
}
return false;
}
+ case GGML_OP_SOLVE_TRI:
+ {
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+ const vk_device& device = ggml_vk_get_device(ctx->device);
+
+ if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
+ return false;
+ }
+ const uint32_t N = op->src[0]->ne[0];
+ const uint32_t K = op->src[1]->ne[0];
+            // K (number of right-hand-side columns) is limited by the 128-invocation workgroup size
+ if (K > 128) {
+ return false;
+ }
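+            // Both shA (N*N floats) and shB (N*K floats) must fit in shared memory.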
+ if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
+ return false;
+ }
+ return true;
+ }
case GGML_OP_ARGMAX:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_COUNT_EQUAL:
tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_COUNT_EQUAL) {
tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
+ } else if (tensor->op == GGML_OP_SOLVE_TRI) {
+ tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false);
} else if (tensor->op == GGML_OP_IM2COL) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
--- /dev/null
+#version 450
+
+#include "types.glsl"
+#include "generic_binary_head.glsl"
+
+layout (constant_id = 1) const uint N = 64;
+layout (constant_id = 2) const uint K = 32;
+
+layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
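+// One workgroup solves one (i2, i3) batch: A and B are staged in shared memory and each invocation forward-substitutes one column of X.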
+
+uint a_base, b_base, x_base;
+
+FLOAT_TYPE get_a(uint r, uint c) {
+ return FLOAT_TYPE(data_a[a_base + r * p.nb01 + c * p.nb00]);
+}
+
+FLOAT_TYPE get_b(uint r, uint c) {
+ return FLOAT_TYPE(data_b[b_base + r * p.nb11 + c * p.nb10]);
+}
+
+void store_x(uint r, uint c, FLOAT_TYPE v) {
+ data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
+}
+
+shared FLOAT_TYPE shA[N * N];
+shared FLOAT_TYPE shB[N * K];
+
+void main() {
+ const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
+
+ if (batch >= p.ne02 * p.ne03) {
+ return;
+ }
+
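+    // Decompose the flat batch index into the (i2, i3) tensor coordinates.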
+ const uint i3 = batch / p.ne22;
+ const uint i2 = batch % p.ne22;
+ a_base = get_aoffset() + i2 * p.nb02 + i3 * p.nb03;
+ b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
+ x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;
+
+ // Load the A matrix into shA
+ [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
+ uint idx = i + tid;
+ if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
+ shA[idx] = get_a(idx / N, idx % N);
+ }
+ }
+ // Load the B matrix into shB
+ [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
+ uint idx = i + tid;
+ if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
+ shB[idx] = get_b(idx / K, idx % K);
+ }
+ }
+ barrier();
+
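+    // Per-invocation copy of this column of X, kept in registers for the substitution sums.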
+ FLOAT_TYPE X[N];
+    // Each invocation forward-substitutes one column of X (column index tid)
+ if (tid < K) {
+ [[unroll]] for (int r = 0; r < N; ++r) {
+ FLOAT_TYPE b = shB[r * K + tid];
+            // x[r,tid] = (b[r,tid] - sum_{c<r} a[r,c] * x[c,tid]) / a[r,r]
+ [[unroll]] for (int c = 0; c < r; ++c) {
+ b -= shA[r * N + c] * X[c];
+ }
+ FLOAT_TYPE x = b / shA[r * N + r];
+ X[r] = x;
+ store_x(r, tid, x);
+ }
+ }
+}