-// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
-// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
-//
// Copyright 2024 Mozilla Foundation
//
// Permission is hereby granted, free of charge, to any person obtaining
};
#endif // __ARM_FEATURE_DOTPROD
-#if defined(__AVX2__) || defined(__AVX512F__)
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
template <typename TA, typename TB, typename TC>
-class tinyBLAS_Q0_AVX2 {
+class tinyBLAS_Q0_AVX {
public:
- tinyBLAS_Q0_AVX2(int64_t k,
- const TA *A, int64_t lda,
- const TB *B, int64_t ldb,
- TC *C, int64_t ldc,
- int ith, int nth)
+ tinyBLAS_Q0_AVX(int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc,
+ int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
__m256 Cv[RN][RM] = {};
for (int64_t l = 0; l < k; ++l)
for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
+ for (int64_t i = 0; i < RM; ++i) {
+#if defined(__AVX2__)
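+                        // AVX2 path: a single 256-bit updot covers all 32 int8 products of the block.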
+ __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+ load(A + lda * (ii + i) + l)),
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+ load(A + lda * (ii + i) + l)));
+#else
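+                        // Plain AVX lacks 256-bit integer instructions, so split the 32-byte
+                        // block into two 128-bit halves and emulate updot with 128-bit intrinsics.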
+ __m128i ali0 = load0(A + lda * (ii + i) + l);
+ __m128i ali1 = load1(A + lda * (ii + i) + l);
+ __m128i blj0 = load0(B + ldb * (jj + j) + l);
+ __m128i blj1 = load1(B + ldb * (jj + j) + l);
+
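+                        // Same signedness trick as the 256-bit path: a*b == |a| * sign(b, a),
+                        // giving _mm_maddubs_epi16 the non-negative first operand it expects.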
+ __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
+ __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
+ __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
+ __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
+
+                        // Emulated updot: _mm_maddubs_epi16 multiplies u8 x i8 pairs into i16,
+                        // then _mm_madd_epi16 against ones sums adjacent pairs into i32 lanes.
+ const __m128i oneFill = _mm_set1_epi16(1);
+ __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
+ __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
+                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1),
+                                                                          _mm_madd_epi16(oneFill, mad0)));
+#endif
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
unhalf(B[ldb * (jj + j) + l].d)),
- updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
- load(A + lda * (ii + i) + l)),
- _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
- load(A + lda * (ii + i) + l))),
- Cv[j][i]);
+ udTmp,
+ Cv[j][i]);
+ }
for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
return _mm256_loadu_si256((const __m256i *)b->qs);
}
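+
+    // AVX (no AVX2) loads: fetch the 32 q8_0 bytes as two 128-bit halves.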
+ inline __m128i load0(const block_q8_0 *b) {
+ return _mm_loadu_si128((const __m128i *)b->qs);
+ }
+
+ inline __m128i load1(const block_q8_0 *b) {
+ return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
+ }
+
inline __m256i load(const block_q4_0 *b) {
return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
}
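+
+    // AVX fallback: the low nibbles hold elements 0-15 and the high nibbles
+    // elements 16-31, each remapped from [0,15] to [-8,7] by subtracting 8.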
+ inline __m128i load0(const block_q4_0 *b) {
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
+ }
+
+ inline __m128i load1(const block_q4_0 *b) {
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
+ }
+
inline __m256 updot(__m256i u, __m256i s) {
__m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
const int ith;
const int nth;
};
-#endif // __AVX2__
+#endif // __AVX__
} // namespace
case GGML_TYPE_Q8_0: {
if (Btype != GGML_TYPE_Q8_0)
return false;
-#if defined(__AVX2__) || defined(__AVX512F__)
- tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+ tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
k, (const block_q8_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
case GGML_TYPE_Q4_0: {
if (Btype != GGML_TYPE_Q8_0)
return false;
-#if defined(__AVX2__) || defined(__AVX512F__)
- tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+ tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
k, (const block_q4_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,