]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sgemm : AVX Q4_0 and Q8_0 (#6891)
authorEve <redacted>
Wed, 8 May 2024 14:29:23 +0000 (14:29 +0000)
committerGitHub <redacted>
Wed, 8 May 2024 14:29:23 +0000 (17:29 +0300)
* basic avx implementation

* style

* combine denibble with load

* reduce 256 to 128 (and back!) conversions

* sse load

* Update sgemm.cpp

* oops

oops

sgemm.cpp

index 4e0159804e8166b118d51d200dfa727a148f9b42..40ba9d7e9a7b728ac92e4ce49c9e8cfad6bbd45e 100644 (file)
--- a/sgemm.cpp
+++ b/sgemm.cpp
@@ -1,6 +1,3 @@
-// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
-// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
-//
 // Copyright 2024 Mozilla Foundation
 //
 // Permission is hereby granted, free of charge, to any person obtaining
@@ -585,15 +582,15 @@ class tinyBLAS_Q0_ARM {
 };
 #endif // __ARM_FEATURE_DOTPROD
 
-#if defined(__AVX2__) || defined(__AVX512F__)
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
 template <typename TA, typename TB, typename TC>
-class tinyBLAS_Q0_AVX2 {
+class tinyBLAS_Q0_AVX {
   public:
-    tinyBLAS_Q0_AVX2(int64_t k,
-                     const TA *A, int64_t lda,
-                     const TB *B, int64_t ldb,
-                     TC *C, int64_t ldc,
-                     int ith, int nth)
+    tinyBLAS_Q0_AVX(int64_t k,
+                    const TA *A, int64_t lda,
+                    const TB *B, int64_t ldb,
+                    TC *C, int64_t ldc,
+                    int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
@@ -728,14 +725,34 @@ class tinyBLAS_Q0_AVX2 {
             __m256 Cv[RN][RM] = {};
             for (int64_t l = 0; l < k; ++l)
                 for (int64_t j = 0; j < RN; ++j)
-                    for (int64_t i = 0; i < RM; ++i)
+                    for (int64_t i = 0; i < RM; ++i) {
+#if defined(__AVX2__)
+                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                              load(A + lda * (ii + i) + l)),
+                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+                                                              load(A + lda * (ii + i) + l)));
+#else
+                        __m128i ali0 = load0(A + lda * (ii + i) + l);
+                        __m128i ali1 = load1(A + lda * (ii + i) + l);
+                        __m128i blj0 = load0(B + ldb * (jj + j) + l);
+                        __m128i blj1 = load1(B + ldb * (jj + j) + l);
+
+                        __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
+                        __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
+                        __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
+                        __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
+
+                        // updot
+                        const __m128i oneFill = _mm_set1_epi16(1);
+                        __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
+                        __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
+                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
+#endif
                         Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
                                                        unhalf(B[ldb * (jj + j) + l].d)),
-                                        updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
-                                                               load(A + lda * (ii + i) + l)),
-                                              _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
-                                                               load(A + lda * (ii + i) + l))),
-                                        Cv[j][i]);
+                                                       udTmp,
+                                                       Cv[j][i]);
+                    }
             for (int64_t j = 0; j < RN; ++j)
                 for (int64_t i = 0; i < RM; ++i)
                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
@@ -746,10 +763,28 @@ class tinyBLAS_Q0_AVX2 {
         return _mm256_loadu_si256((const __m256i *)b->qs);
     }
 
+    inline __m128i load0(const block_q8_0 *b) {
+        return _mm_loadu_si128((const __m128i *)b->qs);
+    }
+
+    inline __m128i load1(const block_q8_0 *b) {
+        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
+    }
+
     inline __m256i load(const block_q4_0 *b) {
         return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
     }
 
+    inline __m128i load0(const block_q4_0 *b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
+    }
+
+    inline __m128i load1(const block_q4_0 *b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
+    }
+
     inline __m256 updot(__m256i u, __m256i s) {
         __m256i res;
 #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
@@ -777,7 +812,7 @@ class tinyBLAS_Q0_AVX2 {
     const int ith;
     const int nth;
 };
-#endif // __AVX2__
+#endif // __AVX__
 
 } // namespace
 
@@ -928,8 +963,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     case GGML_TYPE_Q8_0: {
         if (Btype != GGML_TYPE_Q8_0)
            return false;
-#if defined(__AVX2__) || defined(__AVX512F__)
-        tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
@@ -952,8 +987,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     case GGML_TYPE_Q4_0: {
         if (Btype != GGML_TYPE_Q8_0)
             return false;
-#if defined(__AVX2__) || defined(__AVX512F__)
-        tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,