]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cpu: FA add GEMM microkernel (#19422)
authorAman Gupta <redacted>
Sun, 15 Feb 2026 05:39:24 +0000 (11:09 +0530)
committerGitHub <redacted>
Sun, 15 Feb 2026 05:39:24 +0000 (11:09 +0530)
* ggml-cpu: FA add GEMM microkernel

* add guard for sizeless vector types

* fix case where DV % GGML_F32_EPR !=0

* move memset out of the loop

* move another memset out of the loop

* use RM=4 for arm

* simd_gemm: convert everything to int

* convert everything to size_t to avoid warnings

* fixup

* add pragma for ignoring aggressive loop optimizations

ggml/src/ggml-cpu/common.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cpu/simd-gemm.h [new file with mode: 0644]

index 1057b5bb152dfd51a0f143e180cf2f3c7037fbf6..abbadc359c5addac941f8add92674061c8b790bf 100644 (file)
@@ -6,8 +6,8 @@
 #include "ggml-impl.h"
 #include "simd-mappings.h"
 
-#define GGML_FA_TILE_Q  32
-#define GGML_FA_TILE_KV 16
+#define GGML_FA_TILE_Q  64
+#define GGML_FA_TILE_KV 64
 
 #ifdef __cplusplus
 
index e048d5e5e77d2117df40ebf2413bb0f04e5c6fb0..64eb01a4e18c90f884430830fb7c0f741ef9bbcd 100644 (file)
@@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan(
                         const int64_t DV = node->src[2]->ne[0];
 
                         // Tiled flash attention scratch (tile sizes defined in common.h)
-                        // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
-                        size_t prefill  = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
+                        // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
+                        size_t prefill  = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;
 
                         // Decode path: n_kv_chunks = n_tasks (one chunk per thread)
                         // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
index 4352e132807b5d88d1d6d550f08e50ffe5198388..b7a70e06f1d032ae2c31ae13635e4dc04e4f5583 100644 (file)
@@ -3,6 +3,7 @@
 #include "ggml-cpu.h"
 #include "ggml-impl.h"
 #include "binary-ops.h"
+#include "simd-gemm.h"
 #include "ggml.h"
 #include "unary-ops.h"
 #include "vec.h"
@@ -8389,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
     GGML_ASSERT(k->type == v->type);
     const ggml_type kv_type = k->type;
 
-    const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
-    const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
-    const ggml_vec_dot_t    kv_vec_dot    = kv_type_traits_cpu->vec_dot;
-    const size_t kv_type_size = ggml_type_size(kv_type);
 
     // broadcast factors
     const int64_t rk2 = neq2/nek2;
@@ -8424,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
     static constexpr int Q_TILE_SZ  = ggml_fa_tile_config::Q;
     static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
 
-    GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
-
     int ir = ir0;
     while (ir < ir1) {
         // q indices for the start of this tile
@@ -8452,18 +8447,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
         }
 
         // Per-thread scratch layout:
-        // Q_q:    Q_TILE_SZ * DK (converted Q tile in KV type)
+        // Q_q:    Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
         // KQ:     Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
         // mask:   Q_TILE_SZ * KV_TILE_SZ (mask in float)
         // VKQ32:  Q_TILE_SZ * DV (FP32 output accumulator)
-        // V32:    KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
-        float * base  = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
+        // V32:    KV_TILE_SZ * DV (F32 buffer for V tile)
+        // K_f32:  KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
+        float * base  = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
 
         void  * Q_q    = base;
         float * KQ     = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
         float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
         float * VKQ32  = mask32 + Q_TILE_SZ * KV_TILE_SZ;
-        float * V32    = VKQ32 + Q_TILE_SZ * DV;  // F32 buffer for V tile
+        float * V32    = VKQ32 + Q_TILE_SZ * DV;
+        float * K_f32  = V32 + KV_TILE_SZ * DV;
 
         memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
         memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
@@ -8476,28 +8473,38 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
-        for (int tq = 0; tq < tile_rows; tq++) {
-            const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
-            kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
-        }
-        // Zero-pad remaining rows
-        for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
-            memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
+        {
+            float * Q_f32 = (float *)Q_q;
+            for (int tq = 0; tq < tile_rows; tq++) {
+                const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
+                memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
+            }
+            for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
+                memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
+            }
         }
 
+        memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
+        memset(V32,   0, KV_TILE_SZ * DV * sizeof(float));
+
         for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
+            const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
 
             // skip the tile entirely if all the masks are -inf
             if (mask) {
                 bool can_skip = true;
                 for (int tq = 0; tq < tile_rows; tq++) {
                     const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
-                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+                    for (int tk = 0; tk < kv_tile; tk++) {
                         mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
                         if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
                             can_skip = false;
                         }
                     }
+                    // Pad remaining mask entries with -inf
+                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
+                        mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
+                    }
                 }
 
                 if (can_skip) {
@@ -8505,13 +8512,32 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
                 }
             }
 
-            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
-                const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
-                for (int tk = 0; tk < KV_TILE_SZ; tk++) {
-                    const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
-                    float s;
-                    kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
-                    KQ[tq * KV_TILE_SZ + tk] = s * scale;
+            // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
+            // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
+            for (int tk = 0; tk < kv_tile; tk++) {
+                const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
+                if (kv_type == GGML_TYPE_F16) {
+                    const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
+                    for (int64_t dk = 0; dk < DK; dk++) {
+                        K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
+                    }
+                } else {
+                    const float * k_f32_src = (const float *)k_data;
+                    for (int64_t dk = 0; dk < DK; dk++) {
+                        K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
+                    }
+                }
+            }
+            memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
+            simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
+            ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
+
+            // Set padded KQ entries to -inf so softmax gives them zero weight
+            if (kv_tile < KV_TILE_SZ) {
+                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
+                        KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
+                    }
                 }
             }
 
@@ -8551,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
                 S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
             }
 
-            // Convert V tile to F32 first (if F16), then do MAD
-            // On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
-            // TODO: on ARM, native f16 should be faster
-            if (kv_type == GGML_TYPE_F16) {
-                for (int tk = 0; tk < KV_TILE_SZ; tk++) {
-                    const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
-                    ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
-                }
-                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
-                    if (skip[tq]) continue;
-                    float * vkq_row = VKQ32 + tq * DV;
-                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
-                        const float p = KQ[tq * KV_TILE_SZ + tk];
-                        ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
-                    }
+            // V accumulation: VKQ32 += softmax(KQ) * V
+            // Pack V tile to contiguous F32, zero-padded
+            for (int tk = 0; tk < kv_tile; tk++) {
+                const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
+                if (kv_type == GGML_TYPE_F16) {
+                    ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
+                } else {
+                    memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
                 }
-            } else {
-                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
-                    if (skip[tq]) continue;
-                    float * vkq_row = VKQ32 + tq * DV;
-                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
-                        const float p = KQ[tq * KV_TILE_SZ + tk];
-                        const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
-                        ggml_vec_mad_f32(DV, vkq_row, v_row, p);
-                    }
+            }
+            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                if (skip[tq]) {
+                    memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
                 }
             }
+            simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
         }
 
         // sinks (apply only to valid rows in the tile)
@@ -8794,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
         const int64_t dr = (nr + nchunk - 1) / nchunk;
 
-        static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
         static constexpr int64_t Q_TILE_SZ  = ggml_fa_tile_config::Q;
-        const bool use_tiled = !use_ref &&
+        bool use_tiled = !use_ref &&
                                (q->type == GGML_TYPE_F32 &&
                                 kv_is_f32_or_f16 &&
                                 k->type == v->type &&
-                                nek1 % KV_TILE_SZ == 0 &&
                                 neq1 >= Q_TILE_SZ);
-
+#ifdef GGML_SIMD
+        use_tiled &= (DV % GGML_F32_EPR == 0);
+#endif
         int current_chunk = ith;
 
         while (current_chunk < nchunk) {
diff --git a/ggml/src/ggml-cpu/simd-gemm.h b/ggml/src/ggml-cpu/simd-gemm.h
new file mode 100644 (file)
index 0000000..cd98a1b
--- /dev/null
@@ -0,0 +1,139 @@
+#pragma once
+
+// Computes C[M x N] += A[M x K] * B[K x N]
+
+#include "ggml-cpu-impl.h"
+#include "vec.h"
+#include "common.h"
+#include "simd-mappings.h"
+
+// TODO: add support for sizeless vector types
+#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
+
+// TODO: untested on avx512
+// These are in units of GGML_F32_EPR
+#if defined(__AVX512F__) || defined (__ARM_NEON__)
+    static constexpr int GEMM_RM = 4;
+    static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
+#elif defined(__AVX2__) || defined(__AVX__)
+    static constexpr int GEMM_RM = 6;
+    static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
+#else
+    static constexpr int GEMM_RM = 2;
+    static constexpr int GEMM_RN = 2;
+#endif
+
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations"
+#endif
+
+template <int RM, int RN>
+static inline void simd_gemm_ukernel(
+    float       * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int64_t K, int64_t N,
+    int64_t ii, int64_t jj)
+{
+    static constexpr int KN = GGML_F32_EPR;
+
+    GGML_F32_VEC acc[RM][RN];
+    for (int i = 0; i < RM; i++) {
+        for (int r = 0; r < RN; r++) {
+            acc[i][r] = GGML_F32_VEC_LOAD(C + (ii + i) * N + jj + r * KN);
+        }
+    }
+
+    for (int64_t kk = 0; kk < K; kk++) {
+        GGML_F32_VEC Bv[RN];
+        for (int r = 0; r < RN; r++) {
+            Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + jj + r * KN);
+        }
+        for (int i = 0; i < RM; i++) {
+            GGML_F32_VEC p = GGML_F32_VEC_SET1(A[(ii + i) * K + kk]);
+            for (int r = 0; r < RN; r++) {
+                acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
+            }
+        }
+    }
+
+    for (int i = 0; i < RM; i++) {
+        for (int r = 0; r < RN; r++) {
+            GGML_F32_VEC_STORE(C + (ii + i) * N + jj + r * KN, acc[i][r]);
+        }
+    }
+}
+
+// C[M x N] += A[M x K] * B[K x N]
+static void simd_gemm(
+    float       * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int64_t M, int64_t K, int64_t N)
+{
+    static constexpr int KN = GGML_F32_EPR;
+
+    int64_t ii = 0;
+    for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
+        int64_t jj = 0;
+        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
+            simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C, A, B, K, N, ii, jj);
+        }
+        for (; jj + KN <= N; jj += KN) {
+            simd_gemm_ukernel<GEMM_RM, 1>(C, A, B, K, N, ii, jj);
+        }
+        for (; jj < N; jj++) {
+            for (int i = 0; i < GEMM_RM; i++) {
+                float a = C[(ii + i) * N + jj];
+                for (int64_t kk = 0; kk < K; kk++) {
+                    a += A[(ii + i) * K + kk] * B[kk * N + jj];
+                }
+                C[(ii + i) * N + jj] = a;
+            }
+        }
+    }
+
+    // Tail rows: one at a time
+    for (; ii < M; ii++) {
+        int64_t jj = 0;
+        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
+            simd_gemm_ukernel<1, GEMM_RN>(C, A, B, K, N, ii, jj);
+        }
+        for (; jj + KN <= N; jj += KN) {
+            simd_gemm_ukernel<1, 1>(C, A, B, K, N, ii, jj);
+        }
+        for (; jj < N; jj++) {
+            float a = C[ii * N + jj];
+            for (int64_t kk = 0; kk < K; kk++) {
+                a += A[ii * K + kk] * B[kk * N + jj];
+            }
+            C[ii * N + jj] = a;
+        }
+    }
+}
+
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC diagnostic pop
+#endif
+
+#else // scalar path
+
+static void simd_gemm(
+    float       * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int64_t M, int64_t K, int64_t N)
+{
+    for (int64_t i = 0; i < M; i++) {
+        for (int64_t j = 0; j < N; j++) {
+            float sum = C[i * N + j];
+            for (int64_t kk = 0; kk < K; kk++) {
+                sum += A[i * K + kk] * B[kk * N + j];
+            }
+            C[i * N + j] = sum;
+        }
+    }
+}
+
+#endif // GGML_SIMD