AMD_RDNA2,
AMD_RDNA3,
INTEL_XE2,
+ NVIDIA_PRE_TURING,
};
// HSK x HSV
// https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
return vk_device_architecture::INTEL_XE2;
}
+ } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
+ const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
+
+ bool cooperative_matrix = false;
+
+ // Detect "pre-turing" based on lack of coopmat support.
+ for (const auto& properties : ext_props) {
+ if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
+ cooperative_matrix = true;
+ break;
+ }
+ }
+
+ if (!cooperative_matrix) {
+ return vk_device_architecture::NVIDIA_PRE_TURING;
+ }
}
return vk_device_architecture::OTHER;
}
+enum vk_conv_shapes {
+ CONV_SHAPE_128x128,
+ CONV_SHAPE_64x32,
+ CONV_SHAPE_32x256,
+ CONV_SHAPE_COUNT,
+};
+
struct vk_device_struct {
std::recursive_mutex mutex;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
- vk_pipeline pipeline_conv2d_f32;
- vk_pipeline pipeline_conv2d_f16_f32;
+ vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
+ vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_dw_whcn_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
+
+ // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH
+ uint32_t KWmp; uint32_t KWL;
+ uint32_t KWKHmp; uint32_t KWKHL;
+ uint32_t OWmp; uint32_t OWL;
+ uint32_t OWOHmp; uint32_t OWOHL;
};
+template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
+ // Compute magic values to divide by KW, KW*KH, OW, OW*OH
+ init_fastdiv_values(p.KW, p.KWmp, p.KWL);
+ init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
+ init_fastdiv_values(p.OW, p.OWmp, p.OWL);
+ init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
+}
+
struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
// conv2d
- uint32_t conv2d_WG_SIZE = 256;
- uint32_t conv2d_BS_K = 128;
- uint32_t conv2d_BS_CRS = 16;
- uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
- if (device->subgroup_shuffle &&
- device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316
- use_collectives = 1;
- conv2d_BS_CRS = std::min(
- device->subgroup_size,
- conv2d_BS_CRS); // CRS block size should be capped at sugroup size for correctness when shuffle is used.
- }
- uint32_t conv2d_BS_NPQ = 128;
- uint32_t conv2d_TS_K = 8;
- uint32_t conv2d_shmem_req =
- (conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float);
- if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
- conv2d_BS_CRS = 8;
- if (use_collectives) {
- conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
- }
- }
-
- if (use_collectives) {
- ggml_vk_create_pipeline(
- device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
- sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
- { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
- ggml_vk_create_pipeline(
- device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
- sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
- { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
- } else {
- ggml_vk_create_pipeline(
- device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
- sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
- { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
- false);
- ggml_vk_create_pipeline(
- device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
- sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
- { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
- false);
+ for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
+ uint32_t conv2d_WG_SIZE = 256;
+ uint32_t conv2d_BS_K = 128;
+ uint32_t conv2d_BS_CRS = 16;
+ uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
+ uint32_t conv2d_BS_NPQ = 128;
+ uint32_t conv2d_TS_K = 8;
+ uint32_t conv2d_SHMEM_PAD = 4;
+ bool conv2d_UNROLL = true;
+
+ if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+ conv2d_SHMEM_PAD = 0;
+ conv2d_UNROLL = false;
+ } else if (device->vendor_id == VK_VENDOR_ID_AMD) {
+ conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;
+ }
+
+ switch (s) {
+ default:
+ case CONV_SHAPE_128x128:
+ conv2d_BS_K = 128;
+ conv2d_BS_NPQ = 128;
+ conv2d_BS_CRS = 16;
+ if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) {
+ conv2d_UNROLL = false;
+ }
+ break;
+ case CONV_SHAPE_64x32:
+ conv2d_BS_K = 64;
+ conv2d_BS_NPQ = 32;
+ conv2d_BS_CRS = 32;
+ conv2d_TS_K = 4;
+ break;
+ case CONV_SHAPE_32x256:
+ conv2d_BS_K = 32;
+ conv2d_BS_NPQ = 256;
+ conv2d_BS_CRS = 16;
+ break;
+ }
+
+ // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math.
+ bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA ||
+ device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
+ bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD ||
+ device->architecture == vk_device_architecture::AMD_GCN;
+
+ if (device->subgroup_shuffle &&
+ device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316.
+ allow_collectives_nv &&
+ allow_collectives_amd) {
+ use_collectives = 1;
+ conv2d_BS_CRS = std::min(
+ device->subgroup_size,
+ conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
+ }
+
+ uint32_t conv2d_shmem_req =
+ (conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
+ if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
+ conv2d_BS_CRS = 8;
+ if (use_collectives) {
+ conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
+ }
+ }
+
+ std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
+ std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
+
+ if (conv2d_UNROLL) {
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ } else {
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ }
}
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
}
}
+static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst) {
+ const ggml_tensor *src0 = dst->src[0];
+ const ggml_tensor *src1 = dst->src[1];
+
+ // src0 - kernel: [KW, KH, Cin, Cout]
+ // src1 - input: [W, H, Cin, N]
+ // dst - result: [OW, OH, Cout, N]
+
+ // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
+ auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+ return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+ };
+ // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
+ int64_t W = src1->ne[0];
+ int64_t H = src1->ne[1];
+ int64_t KW = src0->ne[0];
+ int64_t KH = src0->ne[1];
+ int64_t Cout = src0->ne[3];
+ int64_t N = src1->ne[3];
+ int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
+ int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
+ int64_t NPQ = N * OW * OH;
+
+ // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
+ std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
+ return elements;
+}
+
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
switch (op) {
case GGML_OP_GET_ROWS:
case GGML_OP_CONV_2D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+ auto elements = ggml_vk_get_conv_elements(dst);
+ vk_conv_shapes shape;
+
+ uint32_t tiles[CONV_SHAPE_COUNT];
+ for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) {
+ tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]);
+ }
+
+ // We can't query number of shader cores on Intel, use 32 as a placeholder
+ // so small convolutions will still choose a smaller tile.
+ const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
+
+ if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) {
+ shape = CONV_SHAPE_128x128;
+ } else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) {
+ shape = CONV_SHAPE_32x256;
+ } else {
+ shape = CONV_SHAPE_64x32;
+ }
+
if (src0->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_conv2d_f32;
+ return ctx->device->pipeline_conv2d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
- return ctx->device->pipeline_conv2d_f16_f32;
+ return ctx->device->pipeline_conv2d_f16_f32[shape];
}
}
return nullptr;
} break;
case GGML_OP_CONV_2D:
{
- // src0 - kernel: [KW, KH, Cin, Cout]
- // src1 - input: [W, H, Cin, N]
- // dst - result: [OW, OH, Cout, N]
-
- // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
- auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
- return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
- };
- // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
- int64_t W = src1->ne[0];
- int64_t H = src1->ne[1];
- int64_t KW = src0->ne[0];
- int64_t KH = src0->ne[1];
- int64_t Cout = src0->ne[3];
- int64_t N = src1->ne[3];
- int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
- int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
- int64_t NPQ = N * OW * OH;
-
- // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
- elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
- }
- break;
+ elements = ggml_vk_get_conv_elements(dst);
+ } break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
#version 450
+#extension GL_EXT_control_flow_attributes : enable
+
#ifdef USE_COLLECTIVES
# extension GL_KHR_shader_subgroup_shuffle : enable
#endif
#include "types.comp"
-// Make spec constant
-#define SHMEM_PAD 0
-
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
+
+ // fastdiv helper values
+ uint32_t KWmp; uint32_t KWL;
+ uint32_t KWKHmp; uint32_t KWKHL;
+ uint32_t OWmp; uint32_t OWL;
+ uint32_t OWOHmp; uint32_t OWOHL;
}
p;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint use_collectives = 1;
+layout(constant_id = 6) const uint SHMEM_PAD = 4;
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
+// see init_fastdiv_values in ggml-vulkan.cpp
+uint fastdiv(uint n, uint mp, uint L) {
+ uint msbs, lsbs;
+ // msbs = mulhi(n, mp)
+ umulExtended(n, mp, msbs, lsbs);
+ return (msbs + n) >> L;
+}
+
void main() {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t cached_KW_idx;
if (use_collectives == 1) {
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
- cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
+ cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
- cached_KH_idx = cached_CRS_remainder / p.KW;
+ cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
} else {
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
- Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
+ Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
- KH_idx_a = CRS_remainder / p.KW;
+ KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
}
#else
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
- Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
+ Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
- KH_idx_a = CRS_remainder / p.KW;
+ KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
#endif
Ash[B_ly * Ash_stride + B_lx] = val;
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
- for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
+ UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
- uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
+ uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
- uint32_t OH_idx = NPQ_remainder / p.OW;
+ uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
uint32_t CRS_idx_b;
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
} else {
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
- Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
+ Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
- KH_idx_b = CRS_remainder / p.KW;
+ KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
}
#else
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
- Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
+ Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
- KH_idx_b = CRS_remainder / p.KW;
+ KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
#endif
Bsh[B_ly * Bsh_stride + B_lx] = val;
}
barrier();
- for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
- for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
- regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
- }
- for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
- regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
- }
- for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+ if (T_y * TS_K < K) {
+ UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
+ for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+ regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
+ }
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
- regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
+ regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
+ }
+ for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+ for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+ regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
+ }
}
}
}
barrier();
}
/* Save C* */
- for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
- for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
- uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
- uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
- uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
- uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
- uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
- uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
- if (K_idx < K && NPQ_idx < NPQ) {
- dst_data[dst_idx] = regC[T_ly][T_lx];
+ if (T_y * TS_K < K) {
+ for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+ for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+ uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
+ uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
+ uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
+ uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
+ uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
+ uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
+ if (K_idx < K && NPQ_idx < NPQ) {
+ dst_data[dst_idx] = regC[T_ly][T_lx];
+ }
}
}
}