]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml-cpu: support IQ4_NL_4_4 by runtime repack (llama/10541)
authorShupei Fan <redacted>
Thu, 28 Nov 2024 12:52:03 +0000 (20:52 +0800)
committerGeorgi Gerganov <redacted>
Sun, 8 Dec 2024 18:14:35 +0000 (20:14 +0200)
* ggml-cpu: support IQ4_NL_4_4 by runtime repack

* ggml-cpu: add __ARM_FEATURE_DOTPROD guard

ggml/include/ggml-cpu.h
ggml/include/ggml.h
ggml/src/ggml-common.h
ggml/src/ggml-cpu/ggml-cpu-aarch64.c
ggml/src/ggml-cpu/ggml-cpu-aarch64.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ggml-cpu.cpp
ggml/src/ggml.c

index a5358d047a08e6e4682b536129d1464ad5205bcc..e14ea9ea5301f8ef3371dfe0ae389253943880bf 100644 (file)
@@ -91,6 +91,7 @@ extern "C" {
     GGML_BACKEND_API int ggml_cpu_has_neon       (void);
     GGML_BACKEND_API int ggml_cpu_has_arm_fma    (void);
     GGML_BACKEND_API int ggml_cpu_has_fp16_va    (void);
+    GGML_BACKEND_API int ggml_cpu_has_dotprod    (void);
     GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
     GGML_BACKEND_API int ggml_cpu_has_sve        (void);
     GGML_BACKEND_API int ggml_cpu_get_sve_cnt    (void);  // sve vector length in bytes
index 9843b09fbe83ed9d49373ce1630e6502aa0ca18d..65cb92c444bb775756eb323c72b14124ddf71278 100644 (file)
@@ -389,6 +389,9 @@ extern "C" {
         GGML_TYPE_Q4_0_8_8 = 33,
         GGML_TYPE_TQ1_0   = 34,
         GGML_TYPE_TQ2_0   = 35,
+        GGML_TYPE_IQ4_NL_4_4 = 36,
+        // GGML_TYPE_IQ4_NL_4_8 = 37,
+        // GGML_TYPE_IQ4_NL_8_8 = 38,
         GGML_TYPE_COUNT,
     };
 
index 050161393456e91d8fddbe35e935bca9deaad980..27253a6c2b3cafaf7d8bba6670f2dee8cb140b67 100644 (file)
@@ -418,6 +418,12 @@ typedef struct {
 } block_iq4_xs;
 static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
 
+typedef struct {
+    ggml_half d[4];        // deltas for 4 iq4_nl blocks
+    uint8_t qs[QK4_NL * 2];// nibbles / quants for 4 iq4_nl blocks
+} block_iq4_nlx4;
+static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
 #endif // GGML_COMMON_DECL
 #endif // GGML_COMMON_DECL
 
index 96a16dfba1f65ffdb3b78ec98fafc35a08d98ddb..ced37887906711784d987c7ccc5e5b10c2f3470e 100644 (file)
@@ -187,6 +187,8 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
 }
 #endif
 
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
 static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
@@ -528,7 +530,7 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
     UNUSED(blocklen);
 
 #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    if (ggml_cpu_has_neon()) {
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
         const void * b_ptr = vx;
         const void * a_ptr = vy;
         float * res_ptr = s;
@@ -996,6 +998,102 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }
 
+void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float * res_ptr = s;
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+            float32x4_t sumf = vdupq_n_f32(0);
+            for (int l = 0; l < nb; l++) {
+                uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+                uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+                uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+                uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+                int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+                int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+                int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+                int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+                int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+                int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+                int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+                int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+                int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+                int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+                int32x4_t sumi = vdupq_n_s32(0);
+                sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+                sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+                sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+                sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+                sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+                sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+                sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+                sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+                float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                float32x4_t d = a_d * b_d;
+
+                sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+            }
+
+            vst1q_f32(res_ptr + x * 4, sumf);
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
+        float sumf[4];
+        int sumi;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int j = 0; j < ncols_interleaved; j++) {
+                        sumi = 0;
+                        for (int i = 0; i < blocklen; ++i) {
+                            const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                            const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                        }
+                        sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                    }
+                }
+            }
+            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -1017,7 +1115,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
     UNUSED(blocklen);
 
 #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    if (ggml_cpu_has_neon()) {
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
         const void * b_ptr = vx;
         const void * a_ptr = vy;
         float * res_ptr = s;
@@ -3386,6 +3484,117 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }
 
+void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+                float32x4_t sumf[4];
+                for (int m = 0; m < 4; m++) {
+                    sumf[m] = vdupq_n_f32(0);
+                }
+
+                for (int l = 0; l < nb; l++) {
+                    float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                    float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+
+                    int32x4_t sumi_0 = vdupq_n_s32(0);
+                    int32x4_t sumi_1 = vdupq_n_s32(0);
+                    int32x4_t sumi_2 = vdupq_n_s32(0);
+                    int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                    for (int k = 0; k < 4; k++) {
+                        int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+                        int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                        int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+                        int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                    }
+
+                    sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                    sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                    sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                    sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+                }
+
+                for (int m = 0; m < 4; m++) {
+                    vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+                }
+            }
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
+        float sumf[4][4];
+        int sumi;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                }
+                for (int l = 0; l < nb; l++) {
+                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi = 0;
+                                for (int i = 0; i < blocklen; ++i) {
+                                    const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                    const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                                }
+                                sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                            }
+                        }
+                    }
+                }
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++)
+                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
 // FIXME: this code is duplicated from ggml-aarch64.c
 static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
     block_q4_0x4 out;
@@ -3518,6 +3727,70 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block,
     GGML_UNUSED(data_size);
 }
 
+static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
+    block_iq4_nlx4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_NL * 2 / blck_size_interleave;
+
+    if (blck_size_interleave == 8) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            // Using memcpy to avoid unaligned memory accesses
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+        }
+    } else if (blck_size_interleave == 4) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+
+    block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
+    const block_iq4_nl * src = (const block_iq4_nl *)data;
+    block_iq4_nl dst_tmp[4];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
 // Prepare for optimized kernels if applicable
 void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
     if (cur->type == repack_type) {
@@ -3525,20 +3798,30 @@ void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_
         return;
     }
 
-    GGML_ASSERT(cur->type == GGML_TYPE_Q4_0);
-
-    switch (repack_type) {
-        case GGML_TYPE_Q4_0_8_8:
-            repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
-            break;
-        case GGML_TYPE_Q4_0_4_8:
-            repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
-            break;
-        case GGML_TYPE_Q4_0_4_4:
-            repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
-            break;
-        default:
-            GGML_ABORT("Unsupported type");
+    if (cur->type == GGML_TYPE_Q4_0) {
+        switch (repack_type) {
+            case GGML_TYPE_Q4_0_8_8:
+                repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
+                break;
+            case GGML_TYPE_Q4_0_4_8:
+                repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
+                break;
+            case GGML_TYPE_Q4_0_4_4:
+                repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
+                break;
+            default:
+                GGML_ABORT("Unsupported type");
+        }
+    } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        switch (repack_type) {
+            case GGML_TYPE_IQ4_NL_4_4:
+                repack_iq4_nl_to_iq4_nl_4_bl(cur, 4, data, data_size);
+                break;
+            default:
+                GGML_ABORT("Unsupported type");
+        }
+    } else {
+        GGML_ABORT("Unsupported type");
     }
 }
 
@@ -3551,9 +3834,13 @@ enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * c
         if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
             return GGML_TYPE_Q4_0_4_8;
         }
-        if (ggml_cpu_has_neon()) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
             return GGML_TYPE_Q4_0_4_4;
         }
+    } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            return GGML_TYPE_IQ4_NL_4_4;
+        }
     }
 
     return cur->type;
index 53b30c1dd2dfea0b2f34a2039d025f0fb608d3c0..3d9db6a19eb87cb2ebf9e7a6bd426fcc178fa52e 100644 (file)
@@ -15,11 +15,13 @@ void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
 // GEMM
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
 void           ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * data, size_t data_size);
 enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur);
index c6ede19d9d1c0aae72784100093145b4ec66944a..fea867440424e869c55af8d98188f75f4a741ef6 100644 (file)
@@ -109,10 +109,11 @@ static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
 #if defined(__ARM_ARCH)
 struct ggml_arm_arch_features_type {
     int has_neon;
+    int has_dotprod;
     int has_i8mm;
     int has_sve;
     int sve_cnt;
-} ggml_arm_arch_features = {-1, -1, -1, 0};
+} ggml_arm_arch_features = {-1, -1, -1, -1, 0};
 #endif
 
 
@@ -446,6 +447,15 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
         .nrows                    = 1,
     },
+    [GGML_TYPE_IQ4_NL_4_4] = {
+        .from_float               = NULL,
+        .vec_dot                  = NULL,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+        .ncols                    = 4,
+        .gemv                     = ggml_gemv_iq4_nl_4x4_q8_0,
+        .gemm                     = ggml_gemm_iq4_nl_4x4_q8_0,
+    },
 };
 
 const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@@ -2439,6 +2449,7 @@ static void ggml_init_arm_arch_features(void) {
     uint32_t hwcap2 = getauxval(AT_HWCAP2);
 
     ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
+    ggml_arm_arch_features.has_dotprod = !!(hwcap && HWCAP_ASIMDDP);
     ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
     ggml_arm_arch_features.has_sve  = !!(hwcap & HWCAP_SVE);
 
@@ -2453,6 +2464,11 @@ static void ggml_init_arm_arch_features(void) {
     }
     ggml_arm_arch_features.has_neon = oldp;
 
+    if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_dotprod = oldp;
+
     if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
         oldp = 0;
     }
@@ -9133,6 +9149,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q4_0_4_4:
         case GGML_TYPE_Q4_0_4_8:
         case GGML_TYPE_Q4_0_8_8:
+        case GGML_TYPE_IQ4_NL_4_4:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -13880,6 +13897,14 @@ int ggml_cpu_has_neon(void) {
 #endif
 }
 
+int ggml_cpu_has_dotprod(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD)
+    return ggml_arm_arch_features.has_dotprod;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_sve(void) {
 #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
     return ggml_arm_arch_features.has_sve;
index febed433ada2b293af5617e5d7afc2b2c4188164..44d99089a490cb30c47b87ee3c4446fa2f9078c9 100644 (file)
@@ -457,7 +457,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
     const struct ggml_tensor * src1 = op->src[1];
 
     if (src0 && src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
-        if (op->op != GGML_OP_MUL_MAT || src0->type != GGML_TYPE_Q4_0 || ggml_aarch64_get_optimal_repack_type(src0) == GGML_TYPE_Q4_0) {
+        if (op->op != GGML_OP_MUL_MAT || src0->type == ggml_aarch64_get_optimal_repack_type(src0)) {
             return false;
         }
     }
index 1a2318cb188c423c6347e0592a5465c760325715..1a9a7efaf7f39db056e96a05a002d7f0d2eda77e 100644 (file)
@@ -831,6 +831,15 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
         .to_float                 = (ggml_to_float_t) dequantize_row_tq2_0,
         .from_float_ref           = (ggml_from_float_t) quantize_row_tq2_0_ref,
     },
+    [GGML_TYPE_IQ4_NL_4_4] = {
+        .type_name                = "iq4_nl_4x4",
+        .blck_size                = QK4_NL,
+        .blck_size_interleave     = 4,
+        .type_size                = sizeof(block_iq4_nl),
+        .is_quantized             = true,
+        .to_float                 = NULL,
+        .from_float_ref           = NULL,
+    },
 };
 
 const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {