vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_acc_f32;
- vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
- vk_pipeline pipeline_mul_f32;
- vk_pipeline pipeline_div_f32;
+ vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
+ vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+ vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
+ vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
vk_pipeline pipeline_upscale_f32;
vk_pipeline pipeline_scale_f32;
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
return nullptr;
case GGML_OP_ADD:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_add_f32;
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
}
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
- return ctx->device->pipeline_add_f16_f32_f16;
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
}
return nullptr;
case GGML_OP_MUL:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_mul_f32;
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
}
return nullptr;
case GGML_OP_DIV:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_div_f32;
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
}
return nullptr;
case GGML_OP_CONCAT:
#include "types.comp"
#include "generic_binary_head.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.ne) {
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01;
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
+ data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
} else {
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]));
+ data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]));
}
}
#version 450
+#extension GL_EXT_shader_16bit_storage : require
+
#include "types.comp"
#include "generic_binary_head.comp"
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
void main() {
- const uint idx = get_idx();
+ uint idx = get_idx();
- if (idx >= p.ne) {
- return;
- }
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 2;
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)]));
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue;
+ }
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
+ data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
+
+ idx += num_threads;
+ }
}
#include "types.comp"
#include "generic_binary_head.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
const int dim = p.param3;
#include "types.comp"
#include "generic_binary_head.comp"
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
void main() {
- const uint idx = get_idx();
+ uint idx = get_idx();
- if (idx >= p.ne) {
- return;
- }
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 2;
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)]));
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue;
+ }
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
+ data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
+
+ idx += num_threads;
+ }
}
#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_control_flow_attributes : require
layout (push_constant) uniform parameter
{
float param1; float param2; int param3;
} p;
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+// true if src0/src1 are the same shape and the indices can be reused without additional modulus
+layout(constant_id = 0) const bool norepeat = false;
+
uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
-uint src0_idx(uint idx) {
- const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
- const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
- const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
- const uint i02_offset = i02*p.ne01*p.ne00;
- const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
- const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
- return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
+uint fastmod(uint a, uint b) {
+ if ((b & (b-1)) == 0) {
+ return a & (b-1);
+ }
+ return a % b;
}
-uint src1_idx(uint idx) {
- const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
+uint fastdiv(uint a, uint b) {
+ return (a < b) ? 0 : (a / b);
+}
+
+void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
+ i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
- const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
+ i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
const uint i02_offset = i02*p.ne01*p.ne00;
- const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
- const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+ i01 = (idx - i03_offset - i02_offset) / p.ne00;
+ i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+}
+
+uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
+ return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+}
- return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
+uint src1_idx(uint i00, uint i01, uint i02, uint i03) {
+ if (norepeat) {
+ return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10;
+ } else {
+ return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10;
+ }
}
-uint dst_idx(uint idx) {
- const uint i23 = idx / (p.ne22*p.ne21*p.ne20);
- const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
- const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
- const uint i22_offset = i22*p.ne21*p.ne20;
- const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
- const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
- return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
+uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
+ return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20;
}
#include "types.comp"
#include "generic_binary_head.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint i00 = gl_GlobalInvocationID.x;
const uint i10 = gl_GlobalInvocationID.y;
#include "generic_binary_head.comp"
#include "dequant_funcs.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint i00 = (gl_GlobalInvocationID.x)*2;
const uint i10 = gl_GlobalInvocationID.y;
#include "types.comp"
#include "generic_binary_head.comp"
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
void main() {
- const uint idx = get_idx();
+ uint idx = get_idx();
- if (idx >= p.ne) {
- return;
- }
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 2;
- data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)]));
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue;
+ }
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
+ data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
+
+ idx += num_threads;
+ }
}