]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Fix garbled output with REPACK at high thread counts (llama/16956)
authorNoah <redacted>
Tue, 4 Nov 2025 05:04:59 +0000 (05:04 +0000)
committerGeorgi Gerganov <redacted>
Sun, 9 Nov 2025 16:30:22 +0000 (18:30 +0200)
* Fix garbled output with REPACK at high thread counts

Fixed a race condition in the REPACK matrix multiplication code that caused garbled output when using 26+ threads (model-dependent threshold). The issue occurred because with high thread counts, the code forced chunk count to equal thread count, creating many small chunks. After aligning these chunks to NB_COLS boundaries, adjacent chunks could overlap, causing data corruption and race conditions. The fix enforces minimum chunk sizes based on NB_COLS and caps maximum chunk count to prevent creating too many tiny chunks, ensuring proper alignment without overlaps.

* Update ggml/src/ggml-cpu/repack.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update ggml/src/ggml-cpu/repack.cpp

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
src/ggml-cpu/repack.cpp

index 8da1e0e9245b6fee8c36b4805478dbcd96b8cfea..8421c84ce0942e47b450df16bb4ff43ba3b2c36b 100644 (file)
@@ -1678,10 +1678,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
         int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
 
+        // Ensure minimum chunk size to avoid alignment issues with high thread counts
+        // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
+        const int64_t min_chunk_size = NB_COLS;
+        if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
+            nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+        }
+
         if (nth == 1 || nchunk < nth || disable_chunking) {
             nchunk = nth;
         }
 
+        // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
+        // This prevents creating too many tiny chunks that could overlap after alignment
+        const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+        if (nchunk > max_nchunk) {
+            nchunk = max_nchunk;
+        }
+
         if (ith == 0) {
             // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
             ggml_threadpool_chunk_set(params->threadpool, nth);
@@ -1695,8 +1709,15 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         while (current_chunk < nchunk) {
             int64_t src0_start = (current_chunk * ne01) / nchunk;
             int64_t src0_end   = ((current_chunk + 1) * ne01) / nchunk;
+
+            // Align boundaries to NB_COLS - round up to ensure all data is included
+            // The chunk size limiting above ensures chunks are large enough to prevent overlaps
             src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
             src0_end   = (src0_end   % NB_COLS) ? src0_end   + NB_COLS - (src0_end   % NB_COLS) : src0_end;
+            if (src0_end > ne01) {
+                src0_end = ne01;
+            }
+
             if (src0_start >= src0_end) {
                 break;
             }
@@ -1808,8 +1829,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
             int64_t src0_cur_start = (ith * ne01) / nth;
             int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;
 
+            // Align boundaries to NB_COLS - round up to ensure all data is included
             src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
             src0_cur_end   = (src0_cur_end   % NB_COLS) ? src0_cur_end   + NB_COLS - (src0_cur_end   % NB_COLS) : src0_cur_end;
+            if (src0_cur_end > ne01) {
+                src0_cur_end = ne01;
+            }
 
             if (src0_cur_start >= src0_cur_end) {
                 return;