]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : fix AMX and add batched support (llama/19925)
authorAdrien Gallouët <redacted>
Thu, 26 Feb 2026 20:39:11 +0000 (21:39 +0100)
committerGeorgi Gerganov <redacted>
Fri, 27 Feb 2026 10:04:54 +0000 (12:04 +0200)
llama-perplexity -hf ggml-org/Qwen3-0.6B-GGUF:Q4_0 -f wikitext-2-raw/wiki.test.raw -c 2048 -b 2048 --chunks 2

before this commit:

```
perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1
perplexity: 2.31 seconds per pass - ETA 0.07 minutes
[1]17.3868,[2]22.2199,
Final estimate: PPL = 22.2199 +/- 1.59692

llama_perf_context_print:        load time =     878.56 ms
llama_perf_context_print: prompt eval time =    2037.82 ms /  4096 tokens (    0.50 ms per token,  2009.99 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    6403.17 ms /  4097 tokens
llama_perf_context_print:    graphs reused =          0
llama_memory_breakdown_print: | memory breakdown [MiB] | total   free    self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - Host               |                  845 =   318 +     224 +     302                |
llama_memory_breakdown_print: |   - CPU_REPACK         |                  288 =   288 +       0 +       0                |
llama_memory_breakdown_print: |   - AMX                |                   31 =    31 +       0 +       0                |
```

after this commit:

```
perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1
perplexity: 1.98 seconds per pass - ETA 0.05 minutes
[1]17.2005,[2]21.8220,
Final estimate: PPL = 21.8220 +/- 1.56485

llama_perf_context_print:        load time =     719.23 ms
llama_perf_context_print: prompt eval time =    1676.23 ms /  4096 tokens (    0.41 ms per token,  2443.58 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    4258.74 ms /  4097 tokens
llama_perf_context_print:    graphs reused =          0
llama_memory_breakdown_print: | memory breakdown [MiB] | total   free    self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - Host               |                  845 =   318 +     224 +     302                |
llama_memory_breakdown_print: |   - AMX                |                  319 =   319 +       0 +       0                |
```
(no more CPU_REPACK)

after this commit, disabling amx:

```
perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1
perplexity: 2.34 seconds per pass - ETA 0.07 minutes
[1]17.2005,[2]21.8220,
Final estimate: PPL = 21.8220 +/- 1.56485

llama_perf_context_print:        load time =     841.91 ms
llama_perf_context_print: prompt eval time =    2057.28 ms /  4096 tokens (    0.50 ms per token,  1990.98 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    6454.51 ms /  4097 tokens
llama_perf_context_print:    graphs reused =          0
llama_memory_breakdown_print: | memory breakdown [MiB] | total   free    self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - Host               |                  845 =   318 +     224 +     302                |
llama_memory_breakdown_print: |   - CPU_REPACK         |                  319 =   319 +       0 +       0                |
```
=> same perplexity.

Signed-off-by: Adrien Gallouët <redacted>
src/ggml-cpu/amx/amx.cpp
src/ggml-cpu/amx/mmq.cpp

index 895a57137537aa4257955eb9f3833efd48b75d89..9baf3e025e6cc48032a93c4dbea8b17a12f80dce 100644 (file)
@@ -141,27 +141,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ
 namespace ggml::cpu::amx {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        // handle only 2d gemm for now
-        auto is_contiguous_2d = [](const struct ggml_tensor * t) {
-            return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
-        };
-
-        if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) &&  // src0 must be contiguous
-            is_contiguous_2d(op->src[1]) &&                               // src1 must be contiguous
-            op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
-            op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
-            op->ne[0] % (TILE_N * 2) == 0 &&                              // out_features is 32x
-            (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
-            // src1 must be host buffer
-            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+        if (op->op != GGML_OP_MUL_MAT) {
+            return false;
+        }
+        auto * src0 = op->src[0];
+        auto * src1 = op->src[1];
+
+        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
+            return false;
+        }
+        if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) {
+            return false;
+        }
+        if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) {
+            return false;
+        }
+        if (op->ne[0] % (TILE_N * 2)) {
+            return false;
+        }
+        int alignment;
+        switch (src0->type) {
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q8_0:
+                alignment = TILE_K;
+                break;
+            case GGML_TYPE_Q4_K:
+            case GGML_TYPE_Q5_K:
+            case GGML_TYPE_Q6_K:
+            case GGML_TYPE_IQ4_XS:
+                alignment = 256; // QK_K
+                break;
+            case GGML_TYPE_F16:
+                alignment = 16;
+                break;
+            default:
                 return false;
-            }
-            // src1 must be float32
-            if (op->src[1]->type == GGML_TYPE_F32) {
-                return true;
-            }
         }
-        return false;
+        if (src0->ne[0] % alignment) {
+            return false;
+        }
+        if (src1->type != GGML_TYPE_F32) {
+            return false;
+        }
+        return true;
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
index 47c61b88164b8fc141134aaa564298e21a851adb..b5aca76633c6c5a258148b1715390cfe6c714fd5 100644 (file)
@@ -1,4 +1,3 @@
-
 #if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Wpedantic"
 #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
@@ -202,35 +201,27 @@ struct tile_config_t{
 //    advanced-matrix-extensions-intrinsics-functions.html
 //
 
-#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
-void ggml_tile_config_init(void) {
-    static thread_local bool is_first_time = true;
+inline void ggml_tile_config_init(void) {
+    static thread_local bool done = false;
 
-    if (!is_first_time) {
+    if (done) {
         return;
     }
 
-    static thread_local tile_config_t tc;
-    tile_config_t current_tc;
-    _tile_storeconfig(&current_tc);
-
-    // load only when config changes
-    if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
-                               memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
-        tc.palette_id = 1;
-        tc.start_row = 0;
-        TC_CONFIG_TILE(TMM0, 8, 64);
-        TC_CONFIG_TILE(TMM1, 8, 64);
-        TC_CONFIG_TILE(TMM2, 16, 32);
-        TC_CONFIG_TILE(TMM3, 16, 32);
-        TC_CONFIG_TILE(TMM4, 16, 64);
-        TC_CONFIG_TILE(TMM5, 16, 64);
-        TC_CONFIG_TILE(TMM6, 16, 64);
-        TC_CONFIG_TILE(TMM7, 16, 64);
-        _tile_loadconfig(&tc);
-    }
-
-    is_first_time = false;
+    alignas(64) tile_config_t tc = {};
+    tc.palette_id = 1;
+    tc.start_row = 0;
+    tc.rows[0] = 8;   tc.colsb[0] = 64;
+    tc.rows[1] = 8;   tc.colsb[1] = 64;
+    tc.rows[2] = 16;  tc.colsb[2] = 32;
+    tc.rows[3] = 16;  tc.colsb[3] = 32;
+    tc.rows[4] = 16;  tc.colsb[4] = 64;
+    tc.rows[5] = 16;  tc.colsb[5] = 64;
+    tc.rows[6] = 16;  tc.colsb[6] = 64;
+    tc.rows[7] = 16;  tc.colsb[7] = 64;
+
+    _tile_loadconfig(&tc);
+    done = true;
 }
 
 // we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
@@ -268,33 +259,6 @@ int get_row_size(int K) {
     return row_size;
 }
 
-// vectorized dtype conversion
-inline float FP16_TO_FP32(ggml_half val) {
-    __m256i v = _mm256_setr_epi16(
-        val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
-    __m512 o = _mm512_cvtph_ps(v);
-    return _mm512_cvtss_f32(o);
-}
-
-inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
-    __m256i v = _mm256_set1_epi16(val);
-    return _mm512_cvtph_ps(v);
-}
-
-// horizontal reduce
-inline float _mm512_reduce_max_ps(const __m512 x) {
-    __m512 v = x;
-    __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
-    v = _mm512_max_ps(v, v1);
-    v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
-    v = _mm512_max_ps(v, v1);
-    v1 = _mm512_shuffle_ps(v, v, 0x4E);
-    v = _mm512_max_ps(v, v1);
-    v1 = _mm512_shuffle_ps(v, v, 0xB1);
-    v = _mm512_max_ps(v, v1);
-    return _mm512_cvtss_f32(v);
-}
-
 // transpose utils
 #define SHUFFLE_EPI32(a, b, mask) \
     _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
@@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
 
 #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE)                                \
     tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply(    \
-        K, (const float *)src1->data + mb_start * K,                                \
-        (const type *)src0->data + nb_start * K,                                    \
-        (float *)dst->data + mb_start * ldc + nb_start, ldc);
+        K, (const float *)src1->data + src1_offset + mb_start * K,                  \
+        (const type *)src0->data + src0_offset + nb_start * K,                      \
+        (float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc)
 
 
 // re-organize in the format {NB, KB, TILE_SIZE}:
@@ -2019,11 +1983,11 @@ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, B
     }
 };
 
-#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE)                                         \
-    tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply(   \
-        KB, (const char *)wdata + 0 * row_size_A,                                    \
-        (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE),     \
-        (float *) dst->data + 0 * N + nb_start, ldc)
+#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE)                                                   \
+    tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply(             \
+        KB, wdata_batch,                                                                       \
+        (const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
+        (float *) dst->data + dst_offset + nb_start, ldc)
 
 template <typename TA, typename TB, typename TC, int BLOCK_K,
           typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>
@@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
         _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
 
         if (need_unpack) {
-            unpack_B<TB>(Tile1, B_blk0);
+            unpack_B<TB>(Tile1, B_blk1);
             _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
         } else {
             _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
@@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d
     });
 }
 
+// ne2 is passed explicitly to help compiler optimize repeated calls
+inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) {
+    const int64_t i2 = batch_idx % ne2;
+    const int64_t i3 = batch_idx / ne2;
+    return i3 * t->nb[3] + i2 * t->nb[2];
+}
+
 size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
     struct ggml_tensor * src0 = dst->src[0];
 
@@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
 
     const int M = dst->ne[1];
     const int K = src0->ne[0];
+    const int64_t n_batch = dst->ne[2] * dst->ne[3];
 
     size_t desired_wsize = 0;
 
     GGML_DISPATCH_QTYPES(TYPE, [&] {
         const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
-        desired_wsize = M * row_size_A;
+        desired_wsize = n_batch * M * row_size_A;
     });
 
     return desired_wsize;
@@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
 // src1: input  in shape of {M, K}, float32
 // dst:  output in shape of {M, N}, float32
 //
-// the function performs: dst = src1 @ src0.T
+// the function performs: dst = src1 @ src0.T for each batch
 //
 void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) {
     struct ggml_tensor * src0 = dst->src[0];
@@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
     const int K = src0->ne[0];
     const int ldc = dst->nb[1] / dst->nb[0];
 
+    const int64_t ne2 = dst->ne[2];
+    const int64_t n_batch = ne2 * dst->ne[3];
+
     if (is_floating_type) {
         constexpr int BLOCK_M = 4;
         constexpr int BLOCK_N = 6;
         const int MB = div_up(M, BLOCK_M);
         const int NB = div_up(N, BLOCK_N);
 
-        parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
+        parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
             GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
                 for (int i = begin; i < end; ++i) {
-                    int mb = i / NB;
-                    int nb = i % NB;
+                    int batch_idx = i / (MB * NB);
+                    int remaining = i % (MB * NB);
+                    int mb = remaining / NB;
+                    int nb = remaining % NB;
+
+                    int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
+                    int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
+                    int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);
 
                     int mb_start = mb * BLOCK_M;
                     int mb_size = std::min(BLOCK_M, M - mb_start);
@@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
     void * wdata = params->wdata;
 
     //TODO: performance improvement: merge quant A
   if (params->ith == 0) {
// if (params->ith == 0) {
         GGML_DISPATCH_QTYPES(TYPE, [&] {
             const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
-            const size_t desired_wsize = M * row_size_A;
+            const size_t desired_wsize = n_batch * M * row_size_A;
             if (params->wsize < desired_wsize) {
                 GGML_ABORT("insufficient work space size");
             }
@@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
             // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
             GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
 
-            const float * A_data = static_cast<const float *>(src1->data);
-            for (int m = 0; m < M; ++m) {
-                from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
-            }
+            parallel_for_ggml(params, n_batch, [&](int begin, int end) {
+                for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
+                    int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
+                    const float * A_data = (const float *)((const char *)src1->data + src1_offset);
+                    char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
+
+                    for (int m = 0; m < M; ++m) {
+                        from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
+                    }
+                }
+            });
         });
   }
// }
 
     ggml_barrier(params->threadpool);
 
@@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
         constexpr int BLOCK_N = TILE_N * kTilesN;
         const int NB = div_up(N, BLOCK_N);
 
-        parallel_for_ggml(params, NB, [&](int begin, int end) {
+        parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) {
             GGML_DISPATCH_QTYPES(TYPE, [&] {
                 const int KB = K / blck_size;
                 const int TILE_SIZE = get_tile_size<type>();
                 const int row_size_A = KB * sizeof(vec_dot_type);
                 for (int i = begin; i < end; ++i) {
-                    int nb = i;
+                    int batch_idx = i / NB;
+                    int nb = i % NB;
+
+                    int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
+                    int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);
+                    const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A;
+
                     int nb_start = nb * BLOCK_N;
                     int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
 
@@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
     const int MB = div_up(M, BLOCK_M);
     const int NB = div_up(N, BLOCK_N);
 
-    parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
+    parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
         // init tile config for each thread
         ggml_tile_config_init();
 
@@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
             const int row_size_A = KB * sizeof(vec_dot_type);
 
             for (int i = begin; i < end; ++i) {
-                int mb = i / NB;
-                int nb = i % NB;
+                int batch_idx = i / (MB * NB);
+                int remaining = i % (MB * NB);
+                int mb = remaining / NB;
+                int nb = remaining % NB;
+
+                int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
+                int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);
+                const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A;
 
                 int mb_start = mb * BLOCK_M;
                 int mb_size = std::min(BLOCK_M, M - mb_start);
@@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
 
                 tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(
                     mb_size, nb_size, KB,
-                    (const char *)wdata + mb_start * row_size_A,
-                    (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
-                    (float *) dst->data + mb_start * N + nb_start, ldc);
+                    wdata_batch + mb_start * row_size_A,
+                    (const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
+                    (float *) dst->data + dst_offset + mb_start * N + nb_start, ldc);
             }
         });
     });