int row = tid / load_cols;
int col = tid % load_cols;
#pragma unroll
- for (int idx = tid; idx < total_elems; idx += split_d_inner) {
+ for (int idx = 0; idx < total_elems; idx += split_d_inner) {
if (row < (int)split_d_inner) {
smem[row * n_cols + col] = x_block[row * stride_x + col];
}
col += split_d_inner;
row += col / load_cols;
col = col % load_cols;
+ if (idx >= total_elems - tid - split_d_inner) {
+ break;
+ }
}
__syncthreads();