sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
- ACC_TYPE sums[WMITER * TM * WNITER * TN];
+ ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
- FLOAT_TYPE_VEC2 cache_b[TN];
+ FLOAT_TYPE_VEC2 cache_b;
- [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
- sums[i] = ACC_TYPE(0.0f);
+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
+ sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
}
#endif
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
}
}
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
- [[unroll]] for (uint j = 0; j < TN; j++) {
- cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
- }
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
- [[unroll]] for (uint cc = 0; cc < TN; cc++) {
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
- const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
- sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+ cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
+
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+ // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
+ const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
+ sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
+ sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
}
}
}
}
+
}
#endif
}
}
#else
- [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
- sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
+ sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
+ sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
}
#endif
#endif
const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+ const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
#ifdef MUL_MAT_ID
- if (dr_warp + cr < p.M) {
- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+ if (dr_warp + 2 * cr < p.M) {
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
+ }
+ if (dr_warp + 2 * cr + 1 < p.M) {
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
}
#else
- if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
- data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+ if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
+ }
+ if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
}
#endif // MUL_MAT_ID
}