#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define UNUSED GGML_UNUSED
+
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
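// MM256_SET_M128I(a, b) builds a 256-bit vector with b in the low 128 bits and
// a in the high 128 bits; it mirrors _mm256_set_m128i(hi, lo), which some
// older compilers do not provide.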
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
}
#endif
-void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ assert((nrc == 2) || (nrc == 1));
+#else
+ assert(nrc == 1);
+#endif
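+    // nrc is the number of result rows computed per call: the i8mm path below
+    // produces a 2x2 tile at once, all other builds handle one row at a time.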
const block_q4_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ if (nrc == 2) {
+ const block_q4_0 * restrict vx0 = vx;
+ const block_q4_0 * restrict vx1 = vx + bx;
+
+ const block_q8_0 * restrict vy0 = vy;
+ const block_q8_0 * restrict vy1 = vy + by;
+
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+ for (int i = 0; i < nb; i++) {
+ const block_q4_0 * restrict b_x0 = &vx0[i];
+ const block_q4_0 * restrict b_x1 = &vx1[i];
+ const block_q8_0 * restrict b_y0 = &vy0[i];
+ const block_q8_0 * restrict b_y1 = &vy1[i];
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+ const int8x16_t s8b = vdupq_n_s8(0x8);
+
+ const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+ // 4-bit -> 8-bit
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // sub 8
+ const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
+ const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
+ const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
+ const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
+
+ // load y
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+ float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+            sumv0 = vmlaq_f32(sumv0, vcvtq_f32_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(
+                            vdupq_n_s32(0), l0, r0), l1, r1), l2, r2), l3, r3)), scale);
+ }
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+ vst1_f32(s, vget_low_f32(sumv2));
+ vst1_f32(s + bs, vget_high_f32(sumv2));
+ return;
+ }
+#endif
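// Illustrative scalar model of the SMMLA step above (a sketch, not part of the
// patch): vmmlaq_s32(acc, a, b) treats each 16-byte operand as two rows of
// eight int8 values and accumulates the 2x2 matrix of row dot products,
// acc[2*i + j] += sum_k a[8*i + k] * b[8*j + k]. With the zipped l*/r*
// operands this gives the lane order [x0.y0, x0.y1, x1.y0, x1.y1], which is
// exactly what the `scale` vector above is laid out to match.
static inline void smmla_model(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            int32_t sum = 0;
            for (int k = 0; k < 8; k++) {
                sum += (int32_t) a[8*i + k] * (int32_t) b[8*j + k];
            }
            acc[2*i + j] += sum;
        }
    }
}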
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
#endif
}
-void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_1;
const int nb = n / qk;
assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ assert((nrc == 2) || (nrc == 1));
+#else
+ assert(nrc == 1);
+#endif
const block_q4_1 * restrict x = vx;
const block_q8_1 * restrict y = vy;
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ if (nrc == 2) {
+ const block_q4_1 * restrict vx0 = vx;
+ const block_q4_1 * restrict vx1 = vx + bx;
+ const block_q8_1 * restrict vy0 = vy;
+ const block_q8_1 * restrict vy1 = vy + by;
+
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t summs0 = vdupq_n_f32(0.0f);
+
+ for (int i = 0; i < nb; i++) {
+ const block_q4_1 * restrict b_x0 = &vx0[i];
+ const block_q4_1 * restrict b_x1 = &vx1[i];
+ const block_q8_1 * restrict b_y0 = &vy0[i];
+ const block_q8_1 * restrict b_y1 = &vy1[i];
+
+ float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s,
+ GGML_FP16_TO_FP32(b_x1->m) * b_y0->s,
+ GGML_FP16_TO_FP32(b_x0->m) * b_y1->s,
+ GGML_FP16_TO_FP32(b_x1->m) * b_y1->s};
+            summs0 = vaddq_f32(summs0, summs_t);
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+ const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+ // 4-bit -> 8-bit
+ const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // load y
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            // scale factors for the four mmla accumulator lanes
+ float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            sumv0 = vmlaq_f32(sumv0, vcvtq_f32_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(
+                            vdupq_n_s32(0), l0, r0), l1, r1), l2, r2), l3, r3)), scale);
+ }
+
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+        sumv2 = vaddq_f32(sumv2, summs0);
+
+ vst1_f32(s, vget_low_f32(sumv2));
+ vst1_f32(s + bs, vget_high_f32(sumv2));
+ return;
+ }
+#endif
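// Why the separate summs accumulator suffices (sketch, hypothetical helper
// name): q4_1 dequantizes as x[j] = d_x*q_x[j] + m, so for one block
//   sum_j x[j]*y[j] = d_x*d_y * sum_j q_x[j]*q_y[j] + m * (d_y * sum_j q_y[j])
// and block_q8_1 caches s = d_y * sum_j q_y[j], so the offset term is just m*s.
// Scalar reference for one block pair over unpacked nibbles, assuming QK8_1 == 32:
static inline float q4_1_block_dot_model(float dx, float mx, const uint8_t qx[32],
                                         float dy, float sy, const int8_t qy[32]) {
    int32_t isum = 0;
    for (int j = 0; j < 32; j++) {
        isum += (int32_t) qx[j] * (int32_t) qy[j];
    }
    return dx*dy*(float) isum + mx*sy;
}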
// TODO: add WASM SIMD
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
#endif
}
-void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
assert(qk == QK5_0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_q5_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
#endif
}
-void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_1;
const int nb = n / qk;
assert(n % qk == 0);
assert(qk == QK5_1);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_q5_1 * restrict x = vx;
const block_q8_1 * restrict y = vy;
#endif
}
-void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ assert((nrc == 2) || (nrc == 1));
+#else
+ assert(nrc == 1);
+#endif
const block_q8_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ if (nrc == 2) {
+ const block_q8_0 * restrict vx0 = vx;
+ const block_q8_0 * restrict vx1 = vx + bx;
+ const block_q8_0 * restrict vy0 = vy;
+ const block_q8_0 * restrict vy1 = vy + by;
+
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+ for (int i = 0; i < nb; i++) {
+ const block_q8_0 * restrict b_x0 = &vx0[i];
+ const block_q8_0 * restrict b_y0 = &vy0[i];
+
+ const block_q8_0 * restrict b_x1 = &vx1[i];
+ const block_q8_0 * restrict b_y1 = &vy1[i];
+
+ const int8x16_t x0_l = vld1q_s8(b_x0->qs);
+ const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
+ const int8x16_t x1_l = vld1q_s8(b_x1->qs);
+ const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
+
+ // load y
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+ float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+            sumv0 = vmlaq_f32(sumv0, vcvtq_f32_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(
+                            vdupq_n_s32(0), l0, r0), l1, r1), l2, r2), l3, r3)), scale);
+ }
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+ vst1_f32(s, vget_low_f32(sumv2));
+ vst1_f32(s + bs, vget_high_f32(sumv2));
+ return;
+ }
+#endif
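// Layout sketch for the vzip1q_s64/vzip2q_s64 pairs used by all three mmla
// kernels: each int8x16_t is viewed as two 64-bit halves (eight int8 values
// each), and zipping on 64-bit lanes stacks the same half of both rows:
//   l0 = [ x0_l[0..7]  | x1_l[0..7]  ]    l1 = [ x0_l[8..15] | x1_l[8..15] ]
// so every vmmlaq_s32 operand is exactly the 2x8 int8 tile that SMMLA consumes.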
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
}
#if QK_K == 256
-void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_q2_K * restrict x = vx;
const block_q8_K * restrict y = vy;
#else
-void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_q2_K * restrict x = vx;
const block_q8_K * restrict y = vy;
#endif
#if QK_K == 256
-void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
#else
-void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_q3_K * restrict x = vx;
const block_q8_K * restrict y = vy;
#endif
#if QK_K == 256
-void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_q4_K * restrict x = vx;
const block_q8_K * restrict y = vy;
#endif
}
#else
-void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_q4_K * restrict x = vx;
const block_q8_K * restrict y = vy;
#endif
#if QK_K == 256
-void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_q5_K * restrict x = vx;
const block_q8_K * restrict y = vy;
#else
-void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_q5_K * restrict x = vx;
const block_q8_K * restrict y = vy;
#if QK_K == 256
-void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_q6_K * restrict x = vx;
const block_q8_K * restrict y = vy;
#else
-void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_q6_K * restrict x = vx;
const block_q8_K * restrict y = vy;
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
};
-void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_iq2_xxs * restrict x = vx;
const block_q8_K * restrict y = vy;
#endif
}
-void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_iq2_xs * restrict x = vx;
const block_q8_K * restrict y = vy;
}
// TODO
-void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
const block_iq3_xxs * restrict x = vx;
const block_q8_K * restrict y = vy;
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
// Dot product
-void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-
-void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
-void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy);
+void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
//
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
-static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
-static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
+static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
[GGML_TYPE_I8] = {
.is_quantized = false,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
.vec_dot_type = GGML_TYPE_F32,
+ .nrows = 1,
},
[GGML_TYPE_F16] = {
.type_name = "f16",
.from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
.vec_dot_type = GGML_TYPE_F16,
+ .nrows = 1,
},
[GGML_TYPE_Q4_0] = {
.type_name = "q4_0",
.from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
.vec_dot = ggml_vec_dot_q4_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+ .nrows = 2,
+#else
+ .nrows = 1,
+#endif
},
[GGML_TYPE_Q4_1] = {
.type_name = "q4_1",
.from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
.vec_dot = ggml_vec_dot_q4_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+ .nrows = 2,
+#else
+ .nrows = 1,
+#endif
},
[4] = { // GGML_TYPE_Q4_2
.type_name = "DEPRECATED",
.from_float_reference = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_COUNT,
+ .nrows = 1,
},
[5] = { // GGML_TYPE_Q4_3
.type_name = "DEPRECATED",
.from_float_reference = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_COUNT,
+ .nrows = 1,
},
[GGML_TYPE_Q5_0] = {
.type_name = "q5_0",
.from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
.vec_dot = ggml_vec_dot_q5_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
+ .nrows = 1,
},
[GGML_TYPE_Q5_1] = {
.type_name = "q5_1",
.from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
.vec_dot = ggml_vec_dot_q5_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1,
+ .nrows = 1,
},
[GGML_TYPE_Q8_0] = {
.type_name = "q8_0",
.from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
.vec_dot = ggml_vec_dot_q8_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+ .nrows = 2,
+#else
+ .nrows = 1,
+#endif
},
[GGML_TYPE_Q8_1] = {
.type_name = "q8_1",
.from_float = quantize_row_q8_1,
.from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
.vec_dot_type = GGML_TYPE_Q8_1,
+ .nrows = 1,
},
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
.from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
.vec_dot = ggml_vec_dot_q2_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
},
[GGML_TYPE_Q3_K] = {
.type_name = "q3_K",
.from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
.vec_dot = ggml_vec_dot_q3_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
},
[GGML_TYPE_Q4_K] = {
.type_name = "q4_K",
.from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
.vec_dot = ggml_vec_dot_q4_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
},
[GGML_TYPE_Q5_K] = {
.type_name = "q5_K",
.from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
.vec_dot = ggml_vec_dot_q5_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
},
[GGML_TYPE_Q6_K] = {
.type_name = "q6_K",
.from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
.vec_dot = ggml_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
},
[GGML_TYPE_IQ2_XXS] = {
.type_name = "iq2_xxs",
.from_float_reference = NULL,
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
},
[GGML_TYPE_IQ2_XS] = {
.type_name = "iq2_xs",
.from_float_reference = NULL,
.vec_dot = ggml_vec_dot_iq2_xs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
},
[GGML_TYPE_IQ3_XXS] = {
.type_name = "iq3_xxs",
.from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
.vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
},
[GGML_TYPE_Q8_K] = {
.type_name = "q8_K",
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
-static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
+static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
#ifdef GGML_SIMD
float sumf = 0.0f;
const int np = (n & ~(GGML_F32_STEP - 1));
*s = sumf;
}
-static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
ggml_float sumf = 0.0;
#if defined(GGML_SIMD)
#endif
}
-inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); }
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
+ int64_t const vec_dot_num_rows = type_traits[type].nrows;
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
const int64_t blck_0 = 16;
const int64_t blck_1 = 16;
+ // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
+ int64_t nrc = vec_dot_num_rows;
+ // TODO: currently the mmla kernels support only even numbered rows/cols.
+ // this check can be removed once they are extended to support odd numbered rows/cols too
+ if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
+ nrc = 1;
+ }
+
+    const size_t src1_col_stride = (src1_cont || src1->type != vec_dot_type) ? row_size : nb11;
+
// attempt to reduce false-sharing (does not seem to make a difference)
- float tmp[16];
+ // 16 * 2, accounting for mmla kernels
+ float tmp[32];
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
const int64_t i13 = (ir1/(ne12*ne1));
const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
(src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
: (i11*nb11 + i12*nb12 + i13*nb13));
-
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
- vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
+ vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
+ }
+
+ for (int cn = 0; cn < nrc; ++cn) {
+ memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
}
- memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
}
}
}
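// Result layout in the nrc == 2 case (a reading of the code above): vec_dot is
// called with bs = 16, so the kernel writes the dot products against the first
// src1 column into tmp[0..15] and those against the second into tmp[16..31];
// the cn loop then copies each half into its own destination column, nb1/nb0
// floats apart in dst.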
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
- vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
+ vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1);
}
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
}
// linear runtime, no additional memory
float dot_y_dy = 0;
- ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
+ ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
ggml_vec_cpy_f32 (nc, dx, dy);
ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
ggml_vec_mul_f32 (nc, dx, dx, y);
const int i1n = i10*ne11;
for (int i00 = 0; i00 < ne00; i00++) {
float v = 0;
- ggml_vec_dot_f16(ne02, &v,
- (ggml_fp16_t *) wdata_src + i1n,
- (ggml_fp16_t *) wdata_kernel + i00*ne02);
+ ggml_vec_dot_f16(ne02, &v, 0,
+ (ggml_fp16_t *) wdata_src + i1n, 0,
+ (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
dst_data[i10*s0 + i00] += v;
}
}
const int i1n = i10*ne11;
for (int i00 = 0; i00 < ne00; i00++) {
float v = 0;
- ggml_vec_dot_f32(ne02, &v,
- wdata_src + i1n,
- wdata_kernel + i00*ne02);
+ ggml_vec_dot_f32(ne02, &v, 0,
+ wdata_src + i1n, 0,
+ wdata_kernel + i00*ne02, 0, 1);
dst_data[i10*s0 + i00] += v;
}
}
for (int i01 = 0; i01 < ne01; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
float v = 0;
- ggml_vec_dot_f16(ne03, &v,
- wdata_src + i1n,
- wdata_kernel + i01*ne00*ne03 + i00*ne03);
+ ggml_vec_dot_f16(ne03, &v, 0,
+ wdata_src + i1n, 0,
+ wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
}
}
const int i1 = ik1;
ggml_vec_dot_f32(neq0,
- S + i1,
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+ S + i1, 0,
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
}
// scale
const int iv3 = iq3;
ggml_vec_dot_f32(masked_begin,
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
- S);
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
+ S, 0, 1);
}
}
}
const int i1 = ik1;
ggml_vec_dot_f16(neq0,
- S + i1,
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+ S + i1, 0,
+ (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+ (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
}
} else {
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
const int iv3 = iq3;
ggml_vec_dot_f16(nev0,
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
- S16);
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
+ (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
+ S16, 0, 1);
}
} else {
for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
const int i1 = ib01;
ggml_vec_dot_f16(nea0,
- S + i1,
- (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
- (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)));
+ S + i1, 0,
+ (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
+ (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
}
ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
for (int64_t ic = 0; ic < nec01; ++ic) {
ggml_vec_dot_f16(neb01,
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)),
- S16);
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
+ (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
+ S16, 0, 1);
}
ggml_vec_add_f32(nec01,
const int i1 = ik1;
ggml_vec_dot_f32(neq0,
- S + i1,
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+ S + i1, 0,
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
}
// scale
// S = SM * (S - dot(SM, S))
float dot_SM_gradSM = 0;
- ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
+ ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
ggml_vec_mul_f32 (masked_begin, S, S, SM);
}
// compute the initial gradient in the search direction
- ggml_vec_dot_f32(nx, &dginit, g, d);
+ ggml_vec_dot_f32(nx, &dginit, 0, g, 0, d, 0, 1);
// make sure that d points to a descent direction
if (0 < dginit) {
return count;
}
- ggml_vec_dot_f32(nx, &dg, g, d);
+ ggml_vec_dot_f32(nx, &dg, 0, g, 0, d, 0, 1);
// check the Wolfe condition
if (dg < params->lbfgs.wolfe * dginit) {
// ys = y^t \cdot s -> 1 / \rho.
// yy = y^t \cdot y.
//
- ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
- ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
+ ggml_vec_dot_f32(nx, &ys, 0, &lm_y[end[0]*nx], 0, &lm_s[end[0]*nx], 0, 1);
+ ggml_vec_dot_f32(nx, &yy, 0, &lm_y[end[0]*nx], 0, &lm_y[end[0]*nx], 0, 1);
lm_ys[end[0]] = ys;
for (int i = 0; i < bound; ++i) {
j[0] = (j[0] + m - 1) % m;
// \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
- ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
+ ggml_vec_dot_f32(nx, &lm_alpha[j[0]], 0, &lm_s[j[0]*nx], 0, d, 0, 1);
lm_alpha[j[0]] /= lm_ys[j[0]];
// q_{i} = q_{i+1} - \alpha_{i} y_{i}
ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
for (int i = 0; i < bound; ++i) {
// \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
- ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
+ ggml_vec_dot_f32(nx, &beta, 0, &lm_y[j[0]*nx], 0, d, 0, 1);
beta /= lm_ys[j[0]];
// \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
#endif
}
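+// Note: __ARM_FEATURE_MATMUL_INT8 is a compile-time ACLE macro (defined when
+// building with i8mm support, e.g. -march=armv8.2-a+i8mm), so the function
+// below reports what the binary was built with, not what the running CPU
+// supports.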
+int ggml_cpu_has_matmul_int8(void) {
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
////////////////////////////////////////////////////////////////////////////////
GGML_API int ggml_cpu_has_ssse3 (void);
GGML_API int ggml_cpu_has_sycl (void);
GGML_API int ggml_cpu_has_vsx (void);
+ GGML_API int ggml_cpu_has_matmul_int8(void);
//
// Internal types and functions exposed for tests and benchmarks
#endif
typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
- typedef void (*ggml_vec_dot_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
+ typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
+ const void * GGML_RESTRICT y, size_t by, int nrc);
typedef struct {
const char * type_name;
ggml_from_float_t from_float_reference;
ggml_vec_dot_t vec_dot;
enum ggml_type vec_dot_type;
+    int64_t nrows; // number of rows to process simultaneously
} ggml_type_traits_t;
GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
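// Calling-convention sketch for the widened ggml_vec_dot_t, drawn from the
// call sites above rather than a normative spec: a plain single-row dot
// product passes zero strides and nrc == 1,
//
//   float out;
//   ggml_type_traits_t tt = ggml_internal_get_type_traits(GGML_TYPE_Q4_0);
//   tt.vec_dot(n, &out, 0, x_row, 0, y_row, 0, 1);
//
// while the mmla path passes nrc == 2, with bs as the element stride between
// the two result slots in s, and bx/by as the byte strides between the two
// x rows and the two y columns.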