int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+ // Ensure minimum chunk size to avoid alignment issues with high thread counts
+ // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
+ const int64_t min_chunk_size = NB_COLS;
+ if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
+ nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+ }
+
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}
+ // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
+ // This prevents creating too many tiny chunks that could overlap after alignment
+ const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+ if (nchunk > max_nchunk) {
+ nchunk = max_nchunk;
+ }
+
if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
ggml_threadpool_chunk_set(params->threadpool, nth);
while (current_chunk < nchunk) {
int64_t src0_start = (current_chunk * ne01) / nchunk;
int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
+
+ // Align boundaries to NB_COLS - round up to ensure all data is included
+ // The chunk size limiting above ensures chunks are large enough to prevent overlaps
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+ if (src0_end > ne01) {
+ src0_end = ne01;
+ }
+
if (src0_start >= src0_end) {
break;
}
int64_t src0_cur_start = (ith * ne01) / nth;
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
+ // Align boundaries to NB_COLS - round up to ensure all data is included
src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+ if (src0_cur_end > ne01) {
+ src0_cur_end = ne01;
+ }
if (src0_cur_start >= src0_cur_end) {
return;