From: Noah Date: Tue, 4 Nov 2025 05:04:59 +0000 (+0000) Subject: Fix garbled output with REPACK at high thread counts (llama/16956) X-Git-Tag: upstream/0.9.4.185~33 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=d7c5e5ac16985acf33ec145ba0a08c98579e8927;p=pkg%2Fggml%2Fsources%2Fggml Fix garbled output with REPACK at high thread counts (llama/16956) * Fix garbled output with REPACK at high thread counts Fixed a race condition in the REPACK matrix multiplication code that caused garbled output when using 26+ threads (model-dependent threshold). The issue occurred because with high thread counts, the code forced chunk count to equal thread count, creating many small chunks. After aligning these chunks to NB_COLS boundaries, adjacent chunks could overlap, causing data corruption and race conditions. The fix enforces minimum chunk sizes based on NB_COLS and caps maximum chunk count to prevent creating too many tiny chunks, ensuring proper alignment without overlaps. * Update ggml/src/ggml-cpu/repack.cpp Co-authored-by: Georgi Gerganov * Update ggml/src/ggml-cpu/repack.cpp Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- diff --git a/src/ggml-cpu/repack.cpp b/src/ggml-cpu/repack.cpp index 8da1e0e9..8421c84c 100644 --- a/src/ggml-cpu/repack.cpp +++ b/src/ggml-cpu/repack.cpp @@ -1678,10 +1678,24 @@ template 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) { + nchunk = (nr + min_chunk_size - 1) / min_chunk_size; + } + if (nth == 1 || nchunk < nth || disable_chunking) { nchunk = nth; } + // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size + // This prevents creating too many tiny chunks that could overlap after alignment + const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size; + if (nchunk > max_nchunk) { + nchunk = max_nchunk; + } + if (ith == 0) { // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. ggml_threadpool_chunk_set(params->threadpool, nth); @@ -1695,8 +1709,15 @@ template ne01) { + src0_end = ne01; + } + if (src0_start >= src0_end) { break; } @@ -1808,8 +1829,12 @@ template ne01) { + src0_cur_end = ne01; + } if (src0_cur_start >= src0_cur_end) { return;