int unroll_count = 4;
uint unrolled_iters = num_iters & ~(unroll_count - 1);
+#if K_PER_ITER == 2
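+ // If the K dimension (p.ncols) is odd, the last iteration must run with
+ // lastiter==true so the out-of-bounds (OOB) handling is computed correctly.
+ // When the unrolled loop would otherwise cover every iteration, drop one
+ // unrolled group so those iterations are left for the code after this loop.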
+ if ((p.ncols & 1) != 0 &&
+ unrolled_iters == num_iters &&
+ unrolled_iters > 0) {
+ unrolled_iters -= unroll_count;
+ }
+#endif
+
uint i = 0;
while (i < unrolled_iters) {
// Manually partially unroll the loop
i++;
}
}
+
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
+
+#if K_PER_ITER == 2
+ if ((p.ncols & 1) != 0 &&
+ unrolled_iters == num_iters &&
+ unrolled_iters > 0) {
+ unrolled_iters -= unroll_count;
+ }
+#endif
+
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {