static void ggml_vk_destroy_buffer(vk_buffer& buf);
static constexpr uint32_t mul_mat_vec_max_cols = 8;
+static constexpr uint32_t p021_max_gqa_ratio = 8;
enum vk_device_architecture {
OTHER,
bool uma;
bool prefer_host_memory;
bool float_controls_rte_fp16;
+ bool subgroup_add;
bool subgroup_size_control;
uint32_t subgroup_min_size;
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
- vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
+ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
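+ // One p021 pipeline per GQA ratio (1..p021_max_gqa_ratio); the specialization constants are {subgroup/workgroup size, gqa_ratio}.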
+ for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
+ if (device->subgroup_add && device->subgroup_require_full_support) {
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
+ } else {
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
+ }
+ }
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
vk::PhysicalDeviceDriverProperties driver_props;
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
+ vk::PhysicalDeviceVulkan11Properties vk11_props;
vk::PhysicalDeviceVulkan12Properties vk12_props;
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
props2.pNext = &props3;
props3.pNext = &subgroup_props;
subgroup_props.pNext = &driver_props;
- driver_props.pNext = &vk12_props;
+ driver_props.pNext = &vk11_props;
+ vk11_props.pNext = &vk12_props;
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
}
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
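+ // The subgroup-add variant of the p021 shader uses subgroupAdd in a compute shader, so check for compute-stage subgroup arithmetic support.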
+ device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
+
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t d_sz = sizeof(float) * d_ne;
+ // With grouped query attention there is more than one Q matrix per K/V matrix.
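+ // e.g. ne12 = 32 Q heads over ne02 = 8 K/V heads gives gqa_ratio = 4.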
+ uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
+ if (gqa_ratio > p021_max_gqa_ratio || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
+ gqa_ratio = 1;
+ }
+
if (dryrun) {
// Request descriptor sets
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
return;
}
// compute
const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
+
+ uint32_t workgroups_z = (uint32_t)ne12;
+ // When gqa_ratio > 1, each invocation handles multiple rows, so fewer workgroups are needed
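+ // (e.g. ne12 = 32 with gqa_ratio = 4 launches 8 workgroups in z)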
+ if (gqa_ratio > 1) {
+ workgroups_z /= gqa_ratio;
+ }
+
ggml_vk_sync_buffers(subctx);
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
}
static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
+layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
layout (push_constant) uniform parameter
{
uint ncols_x;
const uint idst = channel*nrows_dst + row_dst;
- tmp[tid] = 0.0f;
+ FLOAT_TYPE temp = 0.0f;
- for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
- const uint col_x = col_x0 + tid;
+ // Detect alignment for vector loads
+ bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;
- if (col_x >= p.ncols_x) {
- break;
- }
+ for (uint col_x0 = 0; col_x0 < p.ncols_x;) {
+
+ // Unroll 2x and do vec4 loads if aligned
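+ // (each unrolled iteration loads one vec4 per invocation, i.e. 4*BLOCK_SIZE columns)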
+ const uint unroll_count = 2;
+ if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
+ [[unroll]] for (uint i = 0; i < unroll_count; ++i) {
+ const uint col_x = col_x0 + 4*tid;
+
+ const uint row_y = col_x;
+
+ const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+ const uint iy = channel*nrows_y + row_y;
+
+ const vec4 av4 = vec4(data_a_v4[ix / 4]);
+ const vec4 bv4 = vec4(data_b_v4[iy / 4]);
+
+ temp += dot(av4, bv4);
+
+ col_x0 += 4*BLOCK_SIZE;
+ }
+ // do vec4 loads if aligned
+ } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
+ const uint col_x = col_x0 + 4*tid;
- const uint row_y = col_x;
+ const uint row_y = col_x;
- const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
- const uint iy = channel*nrows_y + row_y;
+ const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+ const uint iy = channel*nrows_y + row_y;
- const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+ const vec4 av4 = vec4(data_a_v4[ix / 4]);
+ const vec4 bv4 = vec4(data_b_v4[iy / 4]);
- tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
+ temp += dot(av4, bv4);
+
+ col_x0 += 4*BLOCK_SIZE;
+ } else {
+ const uint col_x = col_x0 + tid;
+ if (col_x >= p.ncols_x) {
+ break;
+ }
+
+ const uint row_y = col_x;
+
+ const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+ const uint iy = channel*nrows_y + row_y;
+
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+
+ temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
+ col_x0 += BLOCK_SIZE;
+ }
}
+ tmp[tid] = temp;
+
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
+#if USE_SUBGROUP_ADD
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#endif
-#define BLOCK_SIZE 32
#define FLOAT_TYPE float
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
+layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
+layout(constant_id = 0) const int BLOCK_SIZE = 32;
+// gqa_ratio is in the range [1,8]
+layout(constant_id = 1) const uint gqa_ratio = 1;
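+// both BLOCK_SIZE and gqa_ratio are filled in from the {subgroup_size, gqa_ratio} specialization constants set at pipeline creation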
+
layout (push_constant) uniform parameter
{
uint ncols_x;
uint d_offset;
} p;
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
+#if !USE_SUBGROUP_ADD
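+// one shared reduction buffer per possible GQA row (gqa_ratio <= 8)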
+shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
+#endif
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
- const uint channel = gl_GlobalInvocationID.z;
- const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
+
+ uint channel, channel_x;
+
+ // When gqa_ratio > 1, each invocation does multiple rows.
+ // The row in the A matrix starts at channel / gqa_ratio and the
+ // rows in the B matrix are [channel, channel+gqa_ratio).
+ // When gqa_ratio is 1, each invocation does one row.
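+ // e.g. with gqa_ratio = 4, workgroup z = 2 reads A channel 2 and B channels [8, 12).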
+ if (gqa_ratio > 1) {
+ channel_x = gl_GlobalInvocationID.z;
+ channel = channel_x * gqa_ratio;
+ } else {
+ channel = gl_GlobalInvocationID.z;
+ channel_x = channel / (p.nchannels_y / p.nchannels_x);
+ }
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;
- tmp[tid] = FLOAT_TYPE(0.0f);
+ FLOAT_TYPE temp[8];
+ [[unroll]] for (uint i = 0; i < 8; ++i) {
+ temp[i] = FLOAT_TYPE(0.0f);
+ }
+
+ // Detect alignment for vector loads
+ bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
- const uint col_x = col_x0 + tid;
- if (col_x >= p.ncols_x) {
- break;
- }
+ // Use vec4 loads if aligned
+ if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
- // x is transposed and permuted
- const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
- const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+ uint col_x = col_x0 + 4*tid;
+ const uint row_y = col_x;
- const uint row_y = col_x;
+ // x is transposed and permuted
+ const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+ const vec4 av4 = vec4(data_a_v4[ix / 4]);
- // y is not transposed but permuted
- const uint iy = channel*nrows_y + row_y;
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ // y is not transposed but permuted
+ const uint iy = (channel + c)*nrows_y + row_y;
- tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
- }
+ vec4 bv4 = data_b_v4[iy / 4];
+ temp[c] += dot(av4, bv4);
+ }
+
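+ // the loop header adds the remaining BLOCK_SIZE, so each vec4 pass advances 4*BLOCK_SIZE columns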
+ col_x0 += 3*BLOCK_SIZE;
+ } else {
+ const uint col_x = col_x0 + tid;
+
+ if (col_x >= p.ncols_x) {
+ break;
+ }
- // dst is not transposed and not permuted
- const uint idst = channel*nrows_dst + row_dst;
+ // x is transposed and permuted
+ const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+ const uint row_y = col_x;
+
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ // y is not transposed but permuted
+ const uint iy = (channel + c)*nrows_y + row_y;
+
+ temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
+ }
+ }
+ }
+
+#if USE_SUBGROUP_ADD
+ // reduce vec4 at a time
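+ // (temp[] entries beyond gqa_ratio are zero, so reducing fixed vec4 chunks is safe)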
+ vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
+ t = subgroupAdd(t);
+ temp[0] = t[0];
+ temp[1] = t[1];
+ temp[2] = t[2];
+ temp[3] = t[3];
+ if (gqa_ratio > 4) {
+ t = vec4(temp[4], temp[5], temp[6], temp[7]);
+ t = subgroupAdd(t);
+ temp[4] = t[0];
+ temp[5] = t[1];
+ temp[6] = t[2];
+ temp[7] = t[3];
+ }
+#else
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ tmp[c][tid] = temp[c];
+ }
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
- tmp[tid] += tmp[tid + s];
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ temp[c] += tmp[c][tid + s];
+ tmp[c][tid] = temp[c];
+ }
}
barrier();
}
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ temp[c] = tmp[c][tid];
+ }
+#endif
if (tid == 0) {
- dst[idst] = tmp[0];
+ [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+ // dst is not transposed and not permuted
+ const uint idst = (channel + c)*nrows_dst + row_dst;
+ dst[idst] = temp[c];
+ }
}
}
}
}
- string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
+ string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
+ string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
// Norms
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
const std::array<int64_t, 2> bs; // dims 3 and 4
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
const std::array<int64_t, 4> per; // permutation of dimensions
+ const bool v; // whether a is a non-contiguous view
std::string vars() override {
- return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
+ return VARS_TO_STR9(type_a, type_b, m, n, k, bs, nr, per, v);
}
double max_nmse_err() override {
int64_t m = 32, int64_t n = 32, int64_t k = 32,
std::array<int64_t, 2> bs = {10, 10},
std::array<int64_t, 2> nr = {2, 2},
- std::array<int64_t, 4> per = {0, 1, 2, 3})
- : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
+ std::array<int64_t, 4> per = {0, 1, 2, 3},
+ bool v = false)
+ : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
if (npermuted > 0) {
GGML_ASSERT(npermuted == 2);
+ GGML_ASSERT(!v); // not handled
GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
ggml_set_name(a, "a_permuted");
ggml_set_name(b, "b_permuted");
} else {
- a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+
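+ // for the view case, allocate 2*k columns and view only the first k so rows keep the wider stride (non-contiguous)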
+ if (v) {
+ a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]);
+ a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0);
+ } else {
+ a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+ }
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
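+ // n == 1 matrix-vector cases: permuted ({0, 2, 1, 3}) and non-contiguous view sources, with repeat ratios 1 and 4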
+ for (auto bs : {1,2,4,8}) {
+ for (auto nr : {1,4}) {
+ for (uint32_t m = 0; m < 2; ++m) {
+ for (uint32_t k = 0; k < 2; ++k) {
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
+ }
+ }
+ }
+ }
+
// sycl backend will limit task global_range < MAX_INT
// test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)
// however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.)
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
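+ // large-m (permuted) and large-k (non-contiguous view) matrix-vector cases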
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));
+
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {