]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cpu: introduce chunking for repack matmuls and enable matmul-id chunking on ARM64...
authorMax Krasnyansky <redacted>
Thu, 30 Oct 2025 16:06:13 +0000 (09:06 -0700)
committerGeorgi Gerganov <redacted>
Sat, 1 Nov 2025 07:41:35 +0000 (09:41 +0200)
Very similar implementation to the flash-attention chunking, with similar benefits.

src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/repack.cpp

index 9ec485cfa2ff7543502769136c52dd3890eb2381..b5466dd703d1d14fbb1e950cb06b11388c83ef8b 100644 (file)
@@ -1613,13 +1613,8 @@ static void ggml_compute_forward_mul_mat_id(
             chunk_size = 64;
         }
 
-#if defined(__aarch64__)
-        // disable for ARM
-        const bool disable_chunking = true;
-#else
         // disable for NUMA
         const bool disable_chunking = ggml_is_numa();
-#endif // defined(__aarch64__)
 
         int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
         int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
index f531d21e23224370687b09e268119008efb64a18..8da1e0e9245b6fee8c36b4805478dbcd96b8cfea 100644 (file)
@@ -1600,6 +1600,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         return false;
     }
 
+    void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
+        const ggml_tensor * src0 = op->src[0];
+        const ggml_tensor * src1 = op->src[1];
+        ggml_tensor *       dst  = op;
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const void * src1_wdata      = params->wdata;
+        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
+
+        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+        if (ne11 > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        }
+        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                    src0_end - src0_start);
+        }
+    }
+
     void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
         const ggml_tensor * src0 = op->src[0];
         const ggml_tensor * src1 = op->src[1];
@@ -1643,31 +1669,41 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
             from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
         }
 
-        ggml_barrier(params->threadpool);
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();
 
-        const void * src1_wdata      = params->wdata;
-        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
-        int64_t      src0_start      = (ith * ne01) / nth;
-        int64_t      src0_end        = ((ith + 1) * ne01) / nth;
-        src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
-        src0_end   = (src0_end   % NB_COLS) ? src0_end   + NB_COLS - (src0_end   % NB_COLS) : src0_end;
-        if (src0_start >= src0_end) {
-            return;
+        // 4x chunks per thread
+        int64_t nr = ggml_nrows(op->src[0]);
+        int nth_scaled = nth * 4;
+        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+        int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
+
+        if (nth == 1 || nchunk < nth || disable_chunking) {
+            nchunk = nth;
         }
 
-        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-        if (ne11 > 3) {
-            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                    (float *) ((char *) dst->data) + src0_start, ne01,
-                    (const char *) src0->data + src0_start * nb01,
-                    (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        if (ith == 0) {
+            // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
+            ggml_threadpool_chunk_set(params->threadpool, nth);
         }
-        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
-            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                    (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-                    (const char *) src0->data + src0_start * nb01,
-                    (const char *) src1_wdata + (src1_col_stride * iter), 1,
-                    src0_end - src0_start);
+
+        ggml_barrier(params->threadpool);
+
+        // The first chunk comes from our thread_id, the rest will get auto-assigned.
+        int current_chunk = ith;
+
+        while (current_chunk < nchunk) {
+            int64_t src0_start = (current_chunk * ne01) / nchunk;
+            int64_t src0_end   = ((current_chunk + 1) * ne01) / nchunk;
+            src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+            src0_end   = (src0_end   % NB_COLS) ? src0_end   + NB_COLS - (src0_end   % NB_COLS) : src0_end;
+            if (src0_start >= src0_end) {
+                break;
+            }
+
+            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
+
+            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
         }
     }