]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : add mmla kernels for quantized GEMM (llama/4966)
authorsnadampal <redacted>
Sun, 11 Feb 2024 13:22:33 +0000 (07:22 -0600)
committerGeorgi Gerganov <redacted>
Mon, 12 Feb 2024 07:31:11 +0000 (09:31 +0200)
* ggml: aarch64: implement smmla kernel for q8_0_q8_0 quantized gemm

armv8.2-a and above supports MMLA instructions that have higher
throughput than DOT. this commit adds mmla kernel for
q8_0_q8_0 gemm. The feature is enabled if the platform supports
"__ARM_FEATURE_MATMUL_INT8"

On AWS Graviton3 processors this kernel resulted up to 1.5x
improvement for prompt evaluation throughput compared to the
default sdot kernel.

* ggml: aarch64: implement smmla kernel for q4_0_q8_0 quantized gemm

armv8.2-a and above supports MMLA instructions that have higher
throughput than DOT. this commit adds mmla kernel for
q4_0_q8_0 gemm. The feature is enabled if the platform supports
"__ARM_FEATURE_MATMUL_INT8"

On AWS Graviton3 processors this kernel resulted up to 1.5x
improvement for prompt evaluation throughput compared to the
default sdot kernel.

* ggml: aarch64: implement smmla kernel for q4_1_q8_1 quantized gemm

armv8.2-a and above supports MMLA instructions that have higher
throughput than DOT. this commit adds mmla kernel for
q4_1_q8_1 gemm. The feature is enabled if the platform supports
"__ARM_FEATURE_MATMUL_INT8"

On AWS Graviton3 processors this kernel resulted up to 1.5x
improvement for prompt evaluation throughput compared to the
default sdot kernel.

* ggml: update unit tests for the new vec_dot interface

* llama.cpp: add MATMUL_INT8 capability to system_info

ggml-quants.c
ggml-quants.h
ggml.c
ggml.h

index 1031e3761c3c387d2685a9af6c8c2b0005e45a5e..6c122dd2ad413975ce11aa4c05f0d0f8532cd914 100644 (file)
@@ -49,6 +49,8 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+#define UNUSED GGML_UNUSED
+
 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
 
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
@@ -3677,15 +3679,88 @@ static inline __m128i get_scale_shuffle(int i) {
 }
 #endif
 
-void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
 
     assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
 
     const block_q4_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_0 * restrict vx0 = vx;
+        const block_q4_0 * restrict vx1 = vx + bx;
+
+        const block_q8_0 * restrict vy0 = vy;
+        const block_q8_0 * restrict vy1 = vy + by;
+
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q4_0 * restrict b_x0 = &vx0[i];
+            const block_q4_0 * restrict b_x1 = &vx1[i];
+            const block_q8_0 * restrict b_y0 = &vy0[i];
+            const block_q8_0 * restrict b_y1 = &vy1[i];
+
+            const uint8x16_t m4b = vdupq_n_u8(0x0F);
+            const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+            // 4-bit -> 8-bit
+            const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+            const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+            const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+            const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+            // sub 8
+            const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
+            const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
+            const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
+            const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                                                l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+        vst1_f32(s, vget_low_f32(sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+        return;
+    }
+#endif
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -3967,15 +4042,89 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
 #endif
 }
 
-void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     const int qk = QK8_1;
     const int nb = n / qk;
 
     assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
 
     const block_q4_1 * restrict x = vx;
     const block_q8_1 * restrict y = vy;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_1 * restrict vx0 = vx;
+        const block_q4_1 * restrict vx1 = vx + bx;
+        const block_q8_1 * restrict vy0 = vy;
+        const block_q8_1 * restrict vy1 = vy + by;
+
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+        float32x4_t summs0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q4_1 * restrict b_x0 = &vx0[i];
+            const block_q4_1 * restrict b_x1 = &vx1[i];
+            const block_q8_1 * restrict b_y0 = &vy0[i];
+            const block_q8_1 * restrict b_y1 = &vy1[i];
+
+            float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s,
+                                   GGML_FP16_TO_FP32(b_x1->m) * b_y0->s,
+                                   GGML_FP16_TO_FP32(b_x0->m) * b_y1->s,
+                                   GGML_FP16_TO_FP32(b_x1->m) * b_y1->s};
+            summs0 += summs_t;
+
+            const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+            // 4-bit -> 8-bit
+            const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+            const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+            const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+            const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            // mmla into int32x4_t
+            float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                                 GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                                                l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+
+        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+        sumv2 = sumv2 + summs0;
+
+        vst1_f32(s, vget_low_f32(sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+        return;
+    }
+#endif
     // TODO: add WASM SIMD
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
@@ -4107,12 +4256,17 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
 #endif
 }
 
-void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
 
     assert(n % qk == 0);
     assert(qk == QK5_0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q5_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
@@ -4393,12 +4547,17 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
 #endif
 }
 
-void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     const int qk = QK8_1;
     const int nb = n / qk;
 
     assert(n % qk == 0);
     assert(qk == QK5_1);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q5_1 * restrict x = vx;
     const block_q8_1 * restrict y = vy;
@@ -4692,15 +4851,75 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
 #endif
 }
 
-void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
 
     assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
 
     const block_q8_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q8_0 * restrict vx0 = vx;
+        const block_q8_0 * restrict vx1 = vx + bx;
+        const block_q8_0 * restrict vy0 = vy;
+        const block_q8_0 * restrict vy1 = vy + by;
+
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q8_0 * restrict b_x0 = &vx0[i];
+            const block_q8_0 * restrict b_y0 = &vy0[i];
+
+            const block_q8_0 * restrict b_x1 = &vx1[i];
+            const block_q8_0 * restrict b_y1 = &vy1[i];
+
+            const int8x16_t x0_l = vld1q_s8(b_x0->qs);
+            const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
+            const int8x16_t x1_l = vld1q_s8(b_x1->qs);
+            const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                             GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                             GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                             GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                                                       l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+        vst1_f32(s, vget_low_f32(sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+        return;
+    }
+#endif
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -4795,7 +5014,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
 }
 
 #if QK_K == 256
-void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q2_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -5171,7 +5395,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
 #else
 
-void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q2_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -5429,8 +5658,13 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 
 #if QK_K == 256
-void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const uint32_t kmask1 = 0x03030303;
     const uint32_t kmask2 = 0x0f0f0f0f;
@@ -5949,8 +6183,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
 #else
 
-void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q3_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -6292,8 +6531,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 
 #if QK_K == 256
-void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q4_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -6648,8 +6892,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 }
 #else
-void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q4_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -6891,8 +7140,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 
 #if QK_K == 256
-void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy,  size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q5_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -7311,8 +7565,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
 #else
 
-void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q5_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -7577,8 +7836,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
 
 #if QK_K == 256
-void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q6_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -8009,8 +8273,13 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
 #else
 
-void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_q6_K * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -8339,8 +8608,13 @@ static const int8_t keven_signs_q2xs[1024] = {
      1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
 };
 
-void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_iq2_xxs * restrict x = vx;
     const block_q8_K    * restrict y = vy;
@@ -8462,8 +8736,13 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res
 #endif
 }
 
-void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_iq2_xs * restrict x = vx;
     const block_q8_K   * restrict y = vy;
@@ -8682,8 +8961,13 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 }
 
 // TODO
-void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
 
     const block_iq3_xxs * restrict x = vx;
     const block_q8_K    * restrict y = vy;
index bfdf3c99718ee33301db9ba5441b1e9d647821df..68f09b1e12f256a07a3aa47c8732b5dd0b7c0171 100644 (file)
@@ -245,20 +245,20 @@ void dequantize_row_iq2_xs (const block_iq2_xs  * GGML_RESTRICT x, float * GGML_
 void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 
 // Dot product
-void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-
-void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
+void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 //
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
diff --git a/ggml.c b/ggml.c
index 9960573386448ed91a4f11ebd6935a8aea2e0641..d921d82fed7d3358ef08d85042fd08a60770f915 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -428,8 +428,8 @@ int64_t ggml_cycles_per_ms(void) {
 
 static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 
-static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
-static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
+static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
 
 static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
     [GGML_TYPE_I8] = {
@@ -457,6 +457,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized             = false,
         .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
         .vec_dot_type             = GGML_TYPE_F32,
+        .nrows                    = 1,
     },
     [GGML_TYPE_F16] = {
         .type_name                = "f16",
@@ -468,6 +469,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t) ggml_fp32_to_fp16_row,
         .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f16,
         .vec_dot_type             = GGML_TYPE_F16,
+        .nrows                    = 1,
     },
     [GGML_TYPE_Q4_0] = {
         .type_name                = "q4_0",
@@ -479,6 +481,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t) quantize_row_q4_0_reference,
         .vec_dot                  = ggml_vec_dot_q4_0_q8_0,
         .vec_dot_type             = GGML_TYPE_Q8_0,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows                    = 2,
+#else
+        .nrows                    = 1,
+#endif
     },
     [GGML_TYPE_Q4_1] = {
         .type_name                = "q4_1",
@@ -490,6 +497,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t) quantize_row_q4_1_reference,
         .vec_dot                  = ggml_vec_dot_q4_1_q8_1,
         .vec_dot_type             = GGML_TYPE_Q8_1,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows                    = 2,
+#else
+        .nrows                    = 1,
+#endif
     },
     [4] = { // GGML_TYPE_Q4_2
         .type_name                = "DEPRECATED",
@@ -501,6 +513,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = NULL,
         .vec_dot                  = NULL,
         .vec_dot_type             = GGML_TYPE_COUNT,
+        .nrows                    = 1,
     },
     [5] = { // GGML_TYPE_Q4_3
         .type_name                = "DEPRECATED",
@@ -512,6 +525,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = NULL,
         .vec_dot                  = NULL,
         .vec_dot_type             = GGML_TYPE_COUNT,
+        .nrows                    = 1,
     },
     [GGML_TYPE_Q5_0] = {
         .type_name                = "q5_0",
@@ -523,6 +537,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t) quantize_row_q5_0_reference,
         .vec_dot                  = ggml_vec_dot_q5_0_q8_0,
         .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
     },
     [GGML_TYPE_Q5_1] = {
         .type_name                = "q5_1",
@@ -534,6 +549,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t) quantize_row_q5_1_reference,
         .vec_dot                  = ggml_vec_dot_q5_1_q8_1,
         .vec_dot_type             = GGML_TYPE_Q8_1,
+        .nrows                    = 1,
     },
     [GGML_TYPE_Q8_0] = {
         .type_name                = "q8_0",
@@ -545,6 +561,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t) quantize_row_q8_0_reference,
         .vec_dot                  = ggml_vec_dot_q8_0_q8_0,
         .vec_dot_type             = GGML_TYPE_Q8_0,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows                    = 2,
+#else
+        .nrows                    = 1,
+#endif
     },
     [GGML_TYPE_Q8_1] = {
         .type_name                = "q8_1",
@@ -554,6 +575,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float               = quantize_row_q8_1,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q8_1_reference,
         .vec_dot_type             = GGML_TYPE_Q8_1,
+        .nrows                    = 1,
     },
     [GGML_TYPE_Q2_K] = {
         .type_name                = "q2_K",
@@ -565,6 +587,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t) quantize_row_q2_K_reference,
         .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
     },
     [GGML_TYPE_Q3_K] = {
         .type_name                = "q3_K",
@@ -576,6 +599,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t) quantize_row_q3_K_reference,
         .vec_dot                  = ggml_vec_dot_q3_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
     },
     [GGML_TYPE_Q4_K] = {
         .type_name                = "q4_K",
@@ -587,6 +611,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t) quantize_row_q4_K_reference,
         .vec_dot                  = ggml_vec_dot_q4_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
     },
     [GGML_TYPE_Q5_K] = {
         .type_name                = "q5_K",
@@ -598,6 +623,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t) quantize_row_q5_K_reference,
         .vec_dot                  = ggml_vec_dot_q5_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
     },
     [GGML_TYPE_Q6_K] = {
         .type_name                = "q6_K",
@@ -609,6 +635,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t) quantize_row_q6_K_reference,
         .vec_dot                  = ggml_vec_dot_q6_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
     },
     [GGML_TYPE_IQ2_XXS] = {
         .type_name                = "iq2_xxs",
@@ -620,6 +647,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = NULL,
         .vec_dot                  = ggml_vec_dot_iq2_xxs_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
     },
     [GGML_TYPE_IQ2_XS] = {
         .type_name                = "iq2_xs",
@@ -631,6 +659,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = NULL,
         .vec_dot                  = ggml_vec_dot_iq2_xs_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
     },
     [GGML_TYPE_IQ3_XXS] = {
         .type_name                = "iq3_xxs",
@@ -642,6 +671,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float_reference     = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
         .vec_dot                  = ggml_vec_dot_iq3_xxs_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
     },
     [GGML_TYPE_Q8_K] = {
         .type_name                = "q8_K",
@@ -1212,7 +1242,13 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
 inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
 
-static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
+static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
+   assert(nrc == 1);
+   UNUSED(nrc);
+   UNUSED(bx);
+   UNUSED(by);
+   UNUSED(bs);
+
 #ifdef GGML_SIMD
     float sumf = 0.0f;
     const int np = (n & ~(GGML_F32_STEP - 1));
@@ -1249,7 +1285,13 @@ static void ggml_vec_dot_f32(const int n, float * restrict s, const float * rest
     *s = sumf;
 }
 
-static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
     ggml_float sumf = 0.0;
 
 #if defined(GGML_SIMD)
@@ -1455,7 +1497,7 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 #endif
 }
 
-inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s);   }
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
 inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);   }
@@ -9992,6 +10034,7 @@ static void ggml_compute_forward_mul_mat(
     ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
+    int64_t           const vec_dot_num_rows      = type_traits[type].nrows;
 
     GGML_ASSERT(ne0 == ne01);
     GGML_ASSERT(ne1 == ne11);
@@ -10159,12 +10202,23 @@ static void ggml_compute_forward_mul_mat(
     const int64_t blck_0 = 16;
     const int64_t blck_1 = 16;
 
+    // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
+    int64_t nrc = vec_dot_num_rows;
+    // TODO: currently the mmla kernels support only even numbered rows/cols.
+    // this check can be removed once they are extended to support odd numbered rows/cols too
+    if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
+        nrc = 1;
+    }
+
+    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
+
     // attempt to reduce false-sharing (does not seem to make a difference)
-    float tmp[16];
+    // 16 * 2, accounting for mmla kernels
+    float tmp[32];
 
     for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
         for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
-            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
                 const int64_t i13 = (ir1/(ne12*ne1));
                 const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
                 const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
@@ -10187,17 +10241,19 @@ static void ggml_compute_forward_mul_mat(
                     (src1_cont || src1->type != vec_dot_type
                      ? (i11      + i12*ne11 + i13*ne12*ne11)*row_size
                      : (i11*nb11 + i12*nb12 + i13*nb13));
-
                 float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
 
                 //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
                 //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
                 //}
 
-                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                    vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
+                }
+
+                for (int cn = 0; cn < nrc; ++cn) {
+                    memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
                 }
-                memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
             }
         }
     }
@@ -10386,7 +10442,7 @@ static void ggml_compute_forward_mul_mat_id(
                     //}
 
                     for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                        vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
+                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1);
                     }
                     memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
                 }
@@ -11568,7 +11624,7 @@ static void ggml_compute_forward_soft_max_back_f32(
 
         // linear runtime, no additional memory
         float dot_y_dy = 0;
-        ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
+        ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
         ggml_vec_cpy_f32 (nc, dx, dy);
         ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
         ggml_vec_mul_f32 (nc, dx, dx, y);
@@ -12369,9 +12425,9 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
             const int i1n = i10*ne11;
             for (int i00 = 0; i00 < ne00; i00++) {
                 float v = 0;
-                ggml_vec_dot_f16(ne02, &v,
-                        (ggml_fp16_t *)    wdata_src + i1n,
-                        (ggml_fp16_t *) wdata_kernel + i00*ne02);
+                ggml_vec_dot_f16(ne02, &v, 0,
+                        (ggml_fp16_t *)    wdata_src + i1n, 0,
+                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
                 dst_data[i10*s0 + i00] += v;
             }
         }
@@ -12466,9 +12522,9 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
             const int i1n = i10*ne11;
             for (int i00 = 0; i00 < ne00; i00++) {
                 float v = 0;
-                ggml_vec_dot_f32(ne02, &v,
-                        wdata_src + i1n,
-                        wdata_kernel + i00*ne02);
+                ggml_vec_dot_f32(ne02, &v, 0,
+                        wdata_src + i1n, 0,
+                        wdata_kernel + i00*ne02, 0, 1);
                 dst_data[i10*s0 + i00] += v;
             }
         }
@@ -12783,9 +12839,9 @@ static void ggml_compute_forward_conv_transpose_2d(
                 for (int i01 = 0; i01 < ne01; i01++) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         float v = 0;
-                        ggml_vec_dot_f16(ne03, &v,
-                                wdata_src + i1n,
-                                wdata_kernel + i01*ne00*ne03 + i00*ne03);
+                        ggml_vec_dot_f16(ne03, &v, 0,
+                                wdata_src + i1n, 0,
+                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
                         dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
                     }
                 }
@@ -13214,9 +13270,9 @@ static void ggml_compute_forward_flash_attn_f32(
             const int i1 = ik1;
 
             ggml_vec_dot_f32(neq0,
-                    S + i1,
-                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+                    S + i1, 0,
+                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
         }
 
         // scale
@@ -13299,9 +13355,9 @@ static void ggml_compute_forward_flash_attn_f32(
             const int iv3 = iq3;
 
             ggml_vec_dot_f32(masked_begin,
-                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)),
-                    (float *) ((char *) v->data   + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
-                    S);
+                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)), 0,
+                    (float *) ((char *) v->data   + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
+                    S, 0, 1);
         }
     }
 }
@@ -13404,9 +13460,9 @@ static void ggml_compute_forward_flash_attn_f16(
                 const int i1 = ik1;
 
                 ggml_vec_dot_f16(neq0,
-                        S + i1,
-                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+                        S + i1, 0,
+                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
             }
         } else {
             for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
@@ -13508,9 +13564,9 @@ static void ggml_compute_forward_flash_attn_f16(
                 const int iv3 = iq3;
 
                 ggml_vec_dot_f16(nev0,
-                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)),
-                        (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
-                        S16);
+                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)), 0,
+                        (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
+                        S16, 0, 1);
             }
         } else {
             for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
@@ -13652,9 +13708,9 @@ static void ggml_compute_forward_flash_ff_f16(
             const int i1 = ib01;
 
             ggml_vec_dot_f16(nea0,
-                    S + i1,
-                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
-                    (ggml_fp16_t *) ((char *)  a->data + ( ia1*nba1  +  ia2*nba2  +  ia3*nba3)));
+                    S + i1, 0,
+                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
+                    (ggml_fp16_t *) ((char *)  a->data + ( ia1*nba1  +  ia2*nba2  +  ia3*nba3)), 0, 1);
         }
 
         ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
@@ -13677,9 +13733,9 @@ static void ggml_compute_forward_flash_ff_f16(
             for (int64_t ic = 0; ic < nec01; ++ic) {
 
                 ggml_vec_dot_f16(neb01,
-                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1   + i2*nb2   + i3*nb3)),
-                        (ggml_fp16_t *) ((char *) c0->data  + (         ic*nbc01 + i2*nbc02 + i3*nbc03)),
-                        S16);
+                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1   + i2*nb2   + i3*nb3)), 0,
+                        (ggml_fp16_t *) ((char *) c0->data  + (         ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
+                        S16, 0, 1);
             }
 
             ggml_vec_add_f32(nec01,
@@ -13866,9 +13922,9 @@ static void ggml_compute_forward_flash_attn_back_f32(
                     const int i1 = ik1;
 
                     ggml_vec_dot_f32(neq0,
-                            S + i1,
-                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+                            S + i1, 0,
+                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
                 }
 
                 // scale
@@ -14013,7 +14069,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
 
                 // S = SM * (S - dot(SM, S))
                 float dot_SM_gradSM = 0;
-                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
+                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
                 ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
                 ggml_vec_mul_f32 (masked_begin, S, S, SM);
 
@@ -18382,7 +18438,7 @@ static enum ggml_opt_result linesearch_backtracking(
     }
 
     // compute the initial gradient in the search direction
-    ggml_vec_dot_f32(nx, &dginit, g, d);
+    ggml_vec_dot_f32(nx, &dginit, 0, g, 0, d, 0, 1);
 
     // make sure that d points to a descent direction
     if (0 < dginit) {
@@ -18432,7 +18488,7 @@ static enum ggml_opt_result linesearch_backtracking(
                 return count;
             }
 
-            ggml_vec_dot_f32(nx, &dg, g, d);
+            ggml_vec_dot_f32(nx, &dg, 0, g, 0, d, 0, 1);
 
             // check the Wolfe condition
             if (dg < params->lbfgs.wolfe * dginit) {
@@ -18693,8 +18749,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         //     ys = y^t \cdot s    -> 1 / \rho.
         //     yy = y^t \cdot y.
         //
-        ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
-        ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
+        ggml_vec_dot_f32(nx, &ys, 0, &lm_y[end[0]*nx], 0, &lm_s[end[0]*nx], 0, 1);
+        ggml_vec_dot_f32(nx, &yy, 0, &lm_y[end[0]*nx], 0, &lm_y[end[0]*nx], 0, 1);
 
         lm_ys[end[0]] = ys;
 
@@ -18713,7 +18769,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         for (int i = 0; i < bound; ++i) {
             j[0] = (j[0] + m - 1) % m;
             // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
-            ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
+            ggml_vec_dot_f32(nx, &lm_alpha[j[0]], 0, &lm_s[j[0]*nx], 0, d, 0, 1);
             lm_alpha[j[0]] /= lm_ys[j[0]];
             // q_{i} = q_{i+1} - \alpha_{i} y_{i}
             ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
@@ -18723,7 +18779,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
 
         for (int i = 0; i < bound; ++i) {
             // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
-            ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
+            ggml_vec_dot_f32(nx, &beta, 0, &lm_y[j[0]*nx], 0, d, 0, 1);
             beta /= lm_ys[j[0]];
             // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
             ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
@@ -20621,4 +20677,12 @@ int ggml_cpu_has_vsx(void) {
 #endif
 }
 
+int ggml_cpu_has_matmul_int8(void) {
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/ggml.h b/ggml.h
index 51309947110127b04e4d954e794601549f732694..01cecc1e1845ffceafdfed44a2b5de78d78ca41b 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -2290,6 +2290,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_ssse3      (void);
     GGML_API int ggml_cpu_has_sycl       (void);
     GGML_API int ggml_cpu_has_vsx        (void);
+    GGML_API int ggml_cpu_has_matmul_int8(void);
 
     //
     // Internal types and functions exposed for tests and benchmarks
@@ -2303,7 +2304,8 @@ extern "C" {
 #endif
     typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
     typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int k);
-    typedef void (*ggml_vec_dot_t)   (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
+    typedef void (*ggml_vec_dot_t)   (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
+                                      const void * GGML_RESTRICT y, size_t by, int nrc);
 
     typedef struct {
         const char      * type_name;
@@ -2315,6 +2317,7 @@ extern "C" {
         ggml_from_float_t from_float_reference;
         ggml_vec_dot_t    vec_dot;
         enum ggml_type    vec_dot_type;
+        int64_t           nrows; // number of rows to process simultaneously;
     } ggml_type_traits_t;
 
     GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);