vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
+ vk_pipeline pipeline_conv2d_f32;
vk_pipeline pipeline_conv2d_dw_whcn_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
uint32_t H;
};
+struct vk_op_conv2d_push_constants { // Push constants for the conv2d_f32 pipeline; field order mirrors the shader's push_constant block.
+ uint32_t Cout; // output channels (kernel ne[3])
+ uint32_t Cin; // input channels (kernel ne[2])
+ uint32_t N; // batch size (input ne[3])
+
+ uint32_t KW; // kernel width (kernel ne[0])
+ uint32_t KH; // kernel height (kernel ne[1])
+ uint32_t W; // input width (input ne[0])
+ uint32_t H; // input height (input ne[1])
+ uint32_t OW; // output width (dst ne[0])
+ uint32_t OH; // output height (dst ne[1])
+
+ uint32_t s0; // stride along width (op_params[0])
+ uint32_t s1; // stride along height (op_params[1])
+ uint32_t p0; // padding along width (op_params[2])
+ uint32_t p1; // padding along height (op_params[3])
+ uint32_t d0; // dilation along width (op_params[4])
+ uint32_t d1; // dilation along height (op_params[5])
+
+ uint32_t nb01; // kernel strides nb01..nb03, in elements (byte stride / nb00)
+ uint32_t nb02;
+ uint32_t nb03;
+
+ uint32_t nb11; // input strides nb11..nb13, in elements (byte stride / nb10)
+ uint32_t nb12;
+ uint32_t nb13;
+
+ uint32_t nb1; // output strides nb1..nb3, in elements (byte stride / nb0)
+ uint32_t nb2;
+ uint32_t nb3;
+};
+
struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
#endif // GGML_VULKAN_MEMORY_DEBUG
class vk_perf_logger {
-public:
+ public:
void print_timings() {
+ if (timings.empty()) { // Nothing recorded since the last flush: print neither header nor totals.
+ return;
+ }
+ uint64_t total_all_op_times = 0; // Running sum over every op name, reported as "Total time" below.
std::cerr << "----------------\nVulkan Timings:" << std::endl;
- for (const auto& t : timings) {
- uint64_t total = 0;
- for (const auto& time : t.second) {
- total += time;
+ for (const auto & t : timings) {
+ uint64_t total_op_times = 0; // Sum of all samples recorded under this op name.
+ for (const auto & time : t.second) {
+ total_op_times += time;
+ }
+ std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0) // Mean per call; the integer division by the sample count truncates before the / 1000.0.
+ << " us"; // NOTE(review): the "us" label implies samples are nanoseconds - confirm against the caller.
+
+ // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
+ auto it = flops.find(t.first);
+ if (it != flops.end() && (it->second).size() == t.second.size()) {
+ uint64_t total_op_flops = 0;
+ for (const auto & elem : it->second) {
+ total_op_flops += elem;
+ }
+ std::cerr << " ("
+ << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) / // total flops scaled down by 1e9
+ (double(total_op_times) / (1000.0 * 1000.0 * 1000.0)) // total time scaled down by 1e9 (seconds if samples are nanoseconds - TODO confirm)
+ << " GFLOPS/s)";
}
- std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl;
+
+ total_all_op_times += total_op_times; // Fold this op's total into the grand total.
+
+ std::cerr << std::endl;
+ }
+
+ if (timings.size() > 0) { // Always true here: the empty case already returned at the top.
+ std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
}
timings.clear();
+ flops.clear(); // Reset together with timings so the per-op size comparison above stays meaningful.
}
void log_timing(const ggml_tensor * node, uint64_t time) {
return;
}
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
- const uint64_t m = node->src[0]->ne[1];
- const uint64_t n = node->src[1]->ne[1];
- const uint64_t k = node->src[1]->ne[0];
- std::string name = ggml_op_name(node->op);
+ const uint64_t m = node->src[0]->ne[1];
+ const uint64_t n = node->src[1]->ne[1];
+ const uint64_t k = node->src[1]->ne[0];
+ std::string name = ggml_op_name(node->op);
if (n == 1) {
name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
} else {
name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
}
timings[name].push_back(time);
+ flops[name].push_back(m * n * (k + (k - 1))); // m*n outputs, each counted as k multiplies plus (k - 1) adds.
+ return;
+ }
+ if (node->op == GGML_OP_CONV_2D) {
+ std::string name = ggml_op_name(node->op);
+ ggml_tensor * knl = node->src[0]; // src[0] is read as the convolution kernel.
+ uint64_t OW = node->ne[0];
+ uint64_t OH = node->ne[1];
+ uint64_t N = node->ne[3];
+ uint64_t Cout = node->ne[2];
+ uint64_t KW = knl->ne[0];
+ uint64_t KH = knl->ne[1];
+ uint64_t Cin = knl->ne[2];
+ // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
+ uint64_t size_M = Cout;
+ uint64_t size_K = Cin * KW * KH;
+ uint64_t size_N = N * OW * OH;
+ uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1)); // Same flop count as the mat-mul branch, applied to the equivalent matrix shapes.
+ name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
+ ", N=N*OW*OH=" + std::to_string(size_N);
+ flops[name].push_back(n_flops);
+ timings[name].push_back(time);
return;
}
timings[ggml_op_name(node->op)].push_back(time);
}
-private:
+ private:
std::map<std::string, std::vector<uint64_t>> timings;
+ std::map<std::string, std::vector<uint64_t>> flops; // Per-name flop counts; in the code visible here only the mat-mul and CONV_2D branches push entries.
};
struct ggml_backend_vk_context {
}
compile_count++;
}
+
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
};
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ // conv2d
+ uint32_t conv2d_WG_SIZE = 256;
+ uint32_t conv2d_BS_K = 128;
+ uint32_t conv2d_BS_CRS = 16;
+ uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
+ if (device->subgroup_shuffle &&
+ device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316
+ use_collectives = 1;
+ conv2d_BS_CRS = std::min(
+ device->subgroup_size,
+ conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
+ }
+ uint32_t conv2d_BS_NPQ = 128;
+ uint32_t conv2d_TS_K = 8;
+ uint32_t conv2d_shmem_req =
+ (conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float); // Bytes for both shared tiles, budgeting one extra float per tile row. NOTE(review): the shader's SHMEM_PAD is 0, so this over-estimates.
+ if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
+ conv2d_BS_CRS = 8; // Fall back to a smaller CRS tile when the device cannot hold the default one.
+ if (use_collectives) {
+ conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); // Re-apply the subgroup-size cap to the fallback value.
+ }
+ }
+
+ if (use_collectives) { // The two calls below differ only in their last argument.
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
+ { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true); // Spec constants follow the shader's constant_id order 0..5; final `true` is presumably require_full_subgroups - confirm against ggml_vk_create_pipeline.
+ } else {
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
+ { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
+ false);
+ }
+
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
return ctx->device->pipeline_leaky_relu_f32;
}
return nullptr;
+ case GGML_OP_CONV_2D:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+ ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+ return ctx->device->pipeline_conv2d_f32;
+ }
+ return nullptr;
case GGML_OP_CONV_2D_DW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (ggml_is_contiguous(src1)) {
const uint32_t OW = dst->ne[0];
elements = { N * OC * OH * OW, 1, 1};
} break;
+ case GGML_OP_CONV_2D:
+ {
+ // src0 - kernel: [KW, KH, Cin, Cout]
+ // src1 - input: [W, H, Cin, N]
+ // dst - result: [OW, OH, Cout, N]
+
+ // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
+ auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+ return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+ };
+ // parallelize in {Cout/BS_K, NPQ/BS_NPQ, 1}
+ int64_t W = src1->ne[0];
+ int64_t H = src1->ne[1];
+ int64_t KW = src0->ne[0];
+ int64_t KH = src0->ne[1];
+ int64_t Cout = src0->ne[3];
+ int64_t N = src1->ne[3];
+ int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); // height uses the "1" stride/padding/dilation params
+ int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); // width uses the "0" stride/padding/dilation params
+ int64_t NPQ = N * OW * OH;
+
+ // Tile output matrix to (K/BS_K, NPQ/BS_NPQ, 1) workgroups.
+ // elements is given in output-matrix entries; presumably it is divided by the pipeline's
+ // wg_denoms { BS_K, BS_NPQ, 1 } at dispatch - confirm against the dispatch helper.
+ elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
+ }
+ break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
}, dryrun);
}
+// Records a direct 2D convolution on the Vulkan backend.
+//   src0 - F32 kernel, src1 - F32 input, dst - F32 result.
+// Builds the conv2d push constants from the tensor extents, the six op_params
+// (stride, padding, dilation for dims 0 and 1) and the tensor strides converted
+// from bytes to elements, then forwards everything to ggml_vk_op_f32.
+static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
+                            const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // The innermost dimension of every tensor must be densely packed floats.
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    // Kernel output channels must match dst channels, kernel input channels must match src1 channels.
+    GGML_ASSERT(ne03 == ne2);
+    GGML_ASSERT(ne02 == ne12);
+
+    // Every push-constant field is a uint32_t; narrow explicitly in one place.
+    const auto to_u32 = [](auto value) { return static_cast<uint32_t>(value); };
+    const auto * params = dst->op_params;
+
+    vk_op_conv2d_push_constants pc{};
+
+    // Channel counts and batch size.
+    pc.Cout = to_u32(ne03);
+    pc.Cin  = to_u32(ne02);
+    pc.N    = to_u32(ne13);
+
+    // Spatial extents: kernel, input, output.
+    pc.KW = to_u32(ne00);
+    pc.KH = to_u32(ne01);
+    pc.W  = to_u32(ne10);
+    pc.H  = to_u32(ne11);
+    pc.OW = to_u32(ne0);
+    pc.OH = to_u32(ne1);
+
+    // Stride, padding, dilation.
+    pc.s0 = to_u32(params[0]);
+    pc.s1 = to_u32(params[1]);
+    pc.p0 = to_u32(params[2]);
+    pc.p1 = to_u32(params[3]);
+    pc.d0 = to_u32(params[4]);
+    pc.d1 = to_u32(params[5]);
+
+    // Strides expressed in elements rather than bytes.
+    pc.nb01 = to_u32(nb01 / nb00);
+    pc.nb02 = to_u32(nb02 / nb00);
+    pc.nb03 = to_u32(nb03 / nb00);
+
+    pc.nb11 = to_u32(nb11 / nb10);
+    pc.nb12 = to_u32(nb12 / nb10);
+    pc.nb13 = to_u32(nb13 / nb10);
+
+    pc.nb1 = to_u32(nb1 / nb0);
+    pc.nb2 = to_u32(nb2 / nb0);
+    pc.nb3 = to_u32(nb3 / nb0);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(pc), dryrun);
+}
+
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
vk_op_conv2d_dw_push_constants p{};
p.ne = ggml_nelements(dst);
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
+ case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
+ case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_LEAKY_RELU:
{
case GGML_OP_POOL_2D:
ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
+ break;
+ case GGML_OP_CONV_2D:
+ ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
+
break;
case GGML_OP_CONV_2D_DW:
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
+ case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+ } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
+ // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode.
+ auto CRS_size =
+ cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
+ auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
+ total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
}
i += ctx->num_additional_fused_ops;
ctx->num_additional_fused_ops = 0;
return true;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
+        case GGML_OP_CONV_2D:
+            {
+                // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
+                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+                // Look the device up once and reuse it for the vendor check.
+                const vk_device& device = ggml_vk_get_device(ctx->device);
+                const bool is_Apple = device->vendor_id == VK_VENDOR_ID_APPLE;
+                // Channel-contiguous format is not supported yet.
+                // Supported only for all-F32, fully contiguous kernel, input and output.
+                return (op->src[0]->type == GGML_TYPE_F32 &&
+                        op->src[1]->type == GGML_TYPE_F32 &&
+                        op->type == GGML_TYPE_F32 &&
+                        ggml_is_contiguous(op->src[0]) &&
+                        ggml_is_contiguous(op->src[1]) &&
+                        ggml_is_contiguous(op)) && !is_Apple;
+            }
default:
return false;
}
const int32_t p1 = tensor->op_params[6];
tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
+ } else if (tensor->op == GGML_OP_CONV_2D) {
+ const int32_t s0 = tensor->op_params[0];
+ const int32_t s1 = tensor->op_params[1];
+ const int32_t p0 = tensor->op_params[2];
+ const int32_t p1 = tensor->op_params[3];
+ const int32_t d0 = tensor->op_params[4];
+ const int32_t d1 = tensor->op_params[5];
+ tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
const float * op_params = (const float *)tensor->op_params;
tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
--- /dev/null
+#version 450
+
+#ifdef USE_COLLECTIVES
+# extension GL_KHR_shader_subgroup_shuffle : enable
+#endif
+
+#include "types.comp"
+
+// Make spec constant
+#define SHMEM_PAD 0
+
+// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
+layout(binding = 0) readonly buffer A {
+ A_TYPE knl_data[];
+}; // src0 - kernel: [KW, KH, Cin, Cout]
+
+layout(binding = 1) readonly buffer B {
+ B_TYPE src_data[];
+}; // src1 - input: [W, H, Cin, N] -- channel_first format
+
+layout(binding = 2) writeonly buffer D {
+ D_TYPE dst_data[];
+}; // dst - result: [OW, OH, Cout, N]
+
+layout(push_constant) uniform parameter {
+ // I/O channels, batch size
+ uint32_t Cout;
+ uint32_t Cin;
+ uint32_t N;
+
+ // Tensor spatial sizes: kernel, input, output
+ uint32_t KW;
+ uint32_t KH;
+ uint32_t W;
+ uint32_t H;
+ uint32_t OW;
+ uint32_t OH;
+
+ // Parameters: stride, padding, dilation - 0=x (width), 1=y (height)
+ uint32_t s0;
+ uint32_t s1;
+ uint32_t p0;
+ uint32_t p1;
+ uint32_t d0;
+ uint32_t d1;
+
+ // Strides in elements
+ uint32_t nb01;
+ uint32_t nb02;
+ uint32_t nb03;
+
+ uint32_t nb11;
+ uint32_t nb12;
+ uint32_t nb13;
+
+ uint32_t nb1;
+ uint32_t nb2;
+ uint32_t nb3;
+}
+
+p;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; // workgroup size in x is specialization constant 0
+// Blocktile sizes
+layout(constant_id = 1) const uint BS_K = 128;
+layout(constant_id = 2) const uint BS_CRS = 16;
+layout(constant_id = 3) const uint BS_NPQ = 128;
+// Thread-tile sizes
+layout(constant_id = 4) const uint TS_K = 8;
+layout(constant_id = 5) const uint use_collectives = 1; // 1 selects the subgroupShuffle index path in main() (only when USE_COLLECTIVES is defined)
+
+uint32_t tid = gl_LocalInvocationID.x; // thread index within the workgroup
+const uint32_t WG_SIZE = gl_WorkGroupSize.x;
+
+// Ceiling division: how many blocks of block_size are needed to cover work_size items.
+uint splitWork(uint work_size, uint block_size) {
+    uint rounded_up = work_size + block_size - 1;
+    return rounded_up / block_size;
+}
+
+uint32_t K = p.Cout; // rows of the output matrix
+uint32_t CRS = p.Cin * p.KH * p.KW; // reduction length
+uint32_t NPQ = p.N * p.OH * p.OW; // columns of the output matrix
+
+uint32_t n_elems_out = K * NPQ; // not referenced anywhere else in this shader
+
+// Number of blocktiles per input
+uint32_t NB_CRS = splitWork(CRS, BS_CRS);
+
+const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; // row pitch of Ash, in floats
+const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; // row pitch of Bsh, in floats
+
+const uint32_t Ash_numel = BS_K * BS_CRS; // not referenced anywhere else in this shader
+const uint32_t Bsh_numel = BS_CRS * BS_NPQ; // not referenced anywhere else in this shader
+
+const uint32_t Ash_len = BS_K * Ash_stride;
+const uint32_t Bsh_len = BS_CRS * Bsh_stride;
+
+shared float Ash[Ash_len]; // K x CRS
+shared float Bsh[Bsh_len]; // CRS x NPQ
+
+// Threadtile sizes
+const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; // blocktile outputs per thread, divided by the TS_K rows each thread owns
+
+// Number of threadtiles per blocktile
+const uint32_t NT_K = BS_K / TS_K;
+const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
+
+float regA[TS_K]; // one column slice of Ash for this thread's rows
+float regB[TS_NPQ]; // one row slice of Bsh for this thread's columns
+float regC[TS_K][TS_NPQ]; // this thread's accumulators
+
+/*
+Compute
+KxCRS @ CRSxNPQ = K x NPQ
+K=Cout
+C=Cin
+R,S=KH,KW
+P,Q=OH,OW
+*/
+
+uint32_t B_idx_K = gl_WorkGroupID.x; // blocktile index along K
+uint32_t B_idx_NPQ = gl_WorkGroupID.y; // blocktile index along NPQ
+
+uint32_t T_y = tid / NT_NPQ; // threadtile row within the blocktile
+uint32_t T_x = tid % NT_NPQ; // threadtile column within the blocktile
+
+uint32_t Ar = tid / BS_CRS; // row this thread loads in each pass over the A tile
+uint32_t Ac = tid % BS_CRS; // column this thread loads in the A tile
+const uint32_t ArpWg = WG_SIZE / BS_CRS; // A-tile rows loaded by the whole workgroup per pass
+
+uint32_t Br = tid / BS_NPQ; // row this thread loads in each pass over the B tile
+uint32_t Bc = tid % BS_NPQ; // column this thread loads in the B tile
+const uint32_t BrpWg = WG_SIZE / BS_NPQ; // B-tile rows loaded by the whole workgroup per pass
+
+// Entry point: each workgroup computes one BS_K x BS_NPQ tile of the K x NPQ output matrix
+// (K = Cout, NPQ = N*OH*OW) as a product of the kernel viewed as K x CRS and the input
+// gathered on the fly as CRS x NPQ (CRS = Cin*KH*KW). The CRS dimension is walked in
+// blocks of BS_CRS; each block is staged through the shared arrays Ash/Bsh and then
+// accumulated into the per-thread regC tile.
+void main() {
+    // Clear this thread's accumulators.
+    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+            regC[T_ly][T_lx] = 0.0;
+        }
+    }
+    /* Advance block in CRS dim */
+    for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
+        // Decomposition of this thread's A-tile column into (Cin, KH, KW) indices.
+        uint32_t CRS_idx_a;
+        uint32_t Cin_idx_a;
+        uint32_t KH_idx_a;
+        uint32_t KW_idx_a;
+
+#ifdef USE_COLLECTIVES
+        uint32_t cached_CRS_idx;
+        uint32_t cached_Cin_idx;
+        uint32_t cached_KH_idx;
+        uint32_t cached_KW_idx;
+        if (use_collectives == 1) {
+            // Each invocation decomposes the CRS index matching its subgroup lane once;
+            // other lanes then fetch it with subgroupShuffle instead of redoing the divisions.
+            cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
+            cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
+            uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
+            cached_KH_idx = cached_CRS_remainder / p.KW;
+            cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
+
+            CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
+            Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
+            KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
+            KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
+        } else {
+            CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
+            Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
+            uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
+            KH_idx_a = CRS_remainder / p.KW;
+            KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
+        }
+#else
+        CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
+        Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
+        // Declared here: without the declaration this branch does not compile
+        // when USE_COLLECTIVES is undefined.
+        uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
+        KH_idx_a = CRS_remainder / p.KW;
+        KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
+#endif
+
+        /* Load kernel to A_block: (BS_K x BS_CRS)*/
+        for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
+            uint32_t B_ly = r_offset + Ar;
+            uint32_t B_lx = Ac;
+            uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
+            // Clamp the address so the read stays in bounds; out-of-range entries are zeroed below.
+            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
+            float val = knl_data[knl_idx];
+            if (K_idx >= K || CRS_idx_a >= CRS) {
+                val = 0.0;
+            }
+            Ash[B_ly * Ash_stride + B_lx] = val;
+        }
+        /* Load input to B_block: (BS_CRS x BS_NPQ) */
+        for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
+            uint32_t B_ly = r_offset + Br; /* Row index of B block */
+            uint32_t B_lx = Bc;
+            uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
+            uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
+            uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
+            uint32_t OH_idx = NPQ_remainder / p.OW;
+            uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
+
+            uint32_t CRS_idx_b;
+            uint32_t Cin_idx_b;
+            uint32_t KH_idx_b;
+            uint32_t KW_idx_b;
+#ifdef USE_COLLECTIVES
+            if (use_collectives == 1) {
+                CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
+                Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
+                KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
+                KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
+            } else {
+                CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
+                Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
+                uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
+                KH_idx_b = CRS_remainder / p.KW;
+                KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
+            }
+#else
+            CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
+            Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
+            uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
+            KH_idx_b = CRS_remainder / p.KW;
+            KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
+#endif
+
+            // Input coordinates addressed by this (output pixel, kernel tap) pair.
+            // These are unsigned: a position inside the padding wraps to a very large value,
+            // so it is the ">= p.H" / ">= p.W" tests below that reject it (the "< 0" tests
+            // and the max(..., 0) can never trigger on unsigned operands).
+            uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
+            uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
+            // Clamp the address so the read stays in bounds; out-of-range entries are zeroed below.
+            uint32_t src_idx =
+                min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
+            float val = src_data[src_idx];
+            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
+                val = 0.0;
+            }
+            Bsh[B_ly * Bsh_stride + B_lx] = val;
+        }
+        // Both shared tiles must be fully written before any thread reads them.
+        barrier();
+        for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
+            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+                regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
+            }
+            for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+                regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
+            }
+            // Outer-product accumulate of the two slices into the thread tile.
+            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+                for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+                    regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
+                }
+            }
+        }
+        // Keep the next iteration's loads from overwriting tiles still being read.
+        barrier();
+    }
+    /* Save C* */
+    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+            uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
+            uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
+            uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
+            uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
+            uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
+            uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
+            // Tiles on the right/bottom edge may extend past the output matrix; skip those entries.
+            if (K_idx < K && NPQ_idx < NPQ) {
+                dst_data[dst_idx] = regC[T_ly][T_lx];
+            }
+        }
+    }
+}