]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : fix MUL_MAT_ID repack with Q8_K (#12544)
authorGeorgi Gerganov <redacted>
Wed, 26 Mar 2025 11:02:00 +0000 (13:02 +0200)
committerGitHub <redacted>
Wed, 26 Mar 2025 11:02:00 +0000 (13:02 +0200)
* ggml : fix MUL_MAT_ID repack with Q8_K

ggml-ci

* ggml : improve repack templates

ggml-ci

ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp

index e852c8253b44e9bf130e5d6cd1bcf98e653520e1..74a31abb2d6fc5236875770b11cfb1adf8747ccf 100644 (file)
@@ -250,7 +250,7 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
 
 static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
-static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+static void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
@@ -344,7 +344,7 @@ static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRIC
 #endif
 }
 
-static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+static void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
@@ -559,7 +559,7 @@ static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
 #endif
 }
 
-static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+static void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK_K == 256);
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
@@ -811,7 +811,7 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
         // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
         for (int j = 0; j < QK_K * 4; j++) {
             int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
-            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            int src_id     = (j % (4 * blck_size_interleave)) / blck_size_interleave;
             src_offset += (j % blck_size_interleave);
             int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
 
@@ -823,26 +823,25 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
 #endif
 }
 
-static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
+void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
+
+template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
     assert(nrow == 4);
     UNUSED(nrow);
-    if (blck_size_interleave == 4) {
-        quantize_q8_0_4x4(x, vy, n_per_row);
-    } else if (blck_size_interleave == 8) {
-        quantize_q8_0_4x8(x, vy, n_per_row);
-    } else {
-        assert(false);
-    }
+    ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
 }
 
-static void quantize_mat_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
     assert(nrow == 4);
     UNUSED(nrow);
-    if (blck_size_interleave == 8) {
-        quantize_q8_K_4x8(x, vy, n_per_row);
-    } else {
-        assert(false);
-    }
+    ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
+}
+
+template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
 }
 
 static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -5276,52 +5275,50 @@ template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void *
 //}
 
 // gemv
-template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemv(int, float *, size_t, const void *, const void *, int, int);
 
-template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemv<block_q4_K, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
-template <>
-void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
 // gemm
-template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemm(int, float *, size_t, const void *, const void *, int, int);
 
-template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q4_K, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
-template <>
-void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
@@ -5335,32 +5332,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
         // not realy a GGML_TYPE_Q8_0 but same size.
         switch (op->op) {
-        case GGML_OP_MUL_MAT:
-            size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
-            return true;
-        case GGML_OP_MUL_MAT_ID:
-            size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
-            size = GGML_PAD(size, sizeof(int64_t));  // + padding for next bloc.
-            size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
-            return true;
-        default:
-            // GGML_ABORT("fatal error");
-            break;
+            case GGML_OP_MUL_MAT:
+                size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
+                return true;
+            case GGML_OP_MUL_MAT_ID:
+                size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
+                size = GGML_PAD(size, sizeof(int64_t));  // + padding for next bloc.
+                size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
+                return true;
+            default:
+                // GGML_ABORT("fatal error");
+                break;
         }
         return false;
     }
 
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
         switch (op->op) {
-        case GGML_OP_MUL_MAT:
-            forward_mul_mat(params, op);
-            return true;
-        case GGML_OP_MUL_MAT_ID:
-            forward_mul_mat_id(params, op);
-            return true;
-        default:
-            // GGML_ABORT("fatal error");
-            break;
+            case GGML_OP_MUL_MAT:
+                forward_mul_mat(params, op);
+                return true;
+            case GGML_OP_MUL_MAT_ID:
+                forward_mul_mat_id(params, op);
+                return true;
+            default:
+                // GGML_ABORT("fatal error");
+                break;
         }
         return false;
     }
@@ -5399,17 +5396,10 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
 
         int64_t i11_processed = 0;
-        if(PARAM_TYPE == GGML_TYPE_Q8_K) {
-            for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
-                quantize_mat_q8_K((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
-                              INTER_SIZE);
-            }
-        } else {
-            for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
-                quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
-                                INTER_SIZE);
-            }
+        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+            ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
         }
+
         i11_processed = ne11 - ne11 % 4;
         for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
             from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
@@ -5422,22 +5412,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         int64_t      src0_start      = (ith * ne01) / nth;
         int64_t      src0_end        = ((ith + 1) * ne01) / nth;
         src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
-        src0_end   = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+        src0_end   = (src0_end   % NB_COLS) ? src0_end   + NB_COLS - (src0_end   % NB_COLS) : src0_end;
         if (src0_start >= src0_end) {
             return;
         }
 
         // If there are more than three rows in src1, use gemm; otherwise, use gemv.
         if (ne11 > 3) {
-            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
-                                                 (const char *) src0->data + src0_start * nb01,
-                                                 (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
         }
         for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
-            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-                                                 (const char *) src0->data + src0_start * nb01,
-                                                 (const char *) src1_wdata + (src1_col_stride * iter), 1,
-                                                 src0_end - src0_start);
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                    src0_end - src0_start);
         }
     }
 
@@ -5452,7 +5444,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         const int ith = params->ith;
         const int nth = params->nth;
 
-        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
+        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
 
         // we don't support permuted src0 or src1
         GGML_ASSERT(nb00 == ggml_type_size(src0->type));
@@ -5474,7 +5466,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         const int n_ids = ids->ne[0]; // n_expert_used
         const int n_as  = ne02;       // n_expert
 
-        const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+        const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
 
@@ -5486,12 +5478,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
                                       n_as * ne12 * sizeof(mmid_row_mapping)));
 
-        auto                      wdata             = (char *) params->wdata;
-        auto                      wdata_src1_end    = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
-        int64_t *                 matrix_row_counts = (int64_t *) (wdata_src1_end);                      // [n_as]
+        auto * wdata             = (char *)     params->wdata;
+        auto * wdata_src1_end    = (char *)     wdata + GGML_PAD(nbw3, sizeof(int64_t));
+        auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+
         struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as);  // [n_as][ne12]
 
-        // src1: float32 => block_q8_0
+        // src1: float32 => param type
         for (int64_t i12 = 0; i12 < ne12; ++i12) {
             for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
                 from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
@@ -5530,34 +5523,37 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
                 continue;
             }
 
-            auto src0_cur = (const char *) src0->data + cur_a*nb02;
+            const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
 
             //const int64_t nr0 = ne01; // src0 rows
             const int64_t nr1 = cne1; // src1 rows
 
             int64_t src0_cur_start = (ith * ne01) / nth;
             int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;
-            src0_cur_start =
-                (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
-            src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
 
-            if (src0_cur_start >= src0_cur_end) return;
+            src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
+            src0_cur_end   = (src0_cur_end   % NB_COLS) ? src0_cur_end   + NB_COLS - (src0_cur_end   % NB_COLS) : src0_cur_end;
+
+            if (src0_cur_start >= src0_cur_end) {
+                return;
+            }
 
             for (int ir1 = 0; ir1 < nr1; ir1++) {
                 struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
-                const int id       = row_mapping.i1; // selected expert index
 
-                const int64_t  i11 = id % ne11;
-                const int64_t  i12 = row_mapping.i2; // row index in src1
+                const int id = row_mapping.i1; // selected expert index
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = row_mapping.i2; // row index in src1
 
-                const int64_t  i1 = id;  // selected expert index
-                const int64_t  i2 = i12; // row
+                const int64_t i1 = id;  // selected expert index
+                const int64_t i2 = i12; // row
 
-                auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+                const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
 
-                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
-                        ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
-                        ne01,                    src0_cur + src0_cur_start * nb01,
+                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                        (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
+                        src0_cur + src0_cur_start * nb01,
                         src1_col, 1, src0_cur_end - src0_cur_start);
             }
         }
@@ -5578,7 +5574,7 @@ static const tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
 static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
 
 // instance for IQ4
-static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_IQ4_NL> iq4_nl_4x4_q8_0;
+static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
 
 }  // namespace ggml::cpu::aarch64