]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sgemm : improved Q4_0 and Q8_0 performance via 4xN and Mx4 gemm (#8908)
authorSrihari-mcw <redacted>
Sat, 31 Aug 2024 08:20:35 +0000 (13:50 +0530)
committerGitHub <redacted>
Sat, 31 Aug 2024 08:20:35 +0000 (11:20 +0300)
ggml/src/llamafile/sgemm.cpp

index 6626ceb26213f230558cfaae9d26656c6e00b5dc..f0988ba7cd24c5907b8d131fe2dcccee8555a048 100644 (file)
@@ -606,17 +606,29 @@ class tinyBLAS_Q0_AVX {
         case 0x44:
             mc = 4;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<4>(m0, m, n0, n);
+#else
             gemm<4, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x43:
             mc = 4;
             nc = 3;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<3>(m0, m, n0, n);
+#else
             gemm<4, 3>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
             mc = 3;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<3>(m0, m, n0, n);
+#else
             gemm<3, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
             mc = 3;
@@ -626,12 +638,20 @@ class tinyBLAS_Q0_AVX {
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
 #else
         case 0x44:
@@ -639,13 +659,21 @@ class tinyBLAS_Q0_AVX {
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
 #endif
@@ -662,7 +690,11 @@ class tinyBLAS_Q0_AVX {
         case 0x41:
             mc = 4;
             nc = 1;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<1>(m0, m, n0, n);
+#else
             gemm<4, 1>(m0, m, n0, n);
+#endif
             break;
         case 0x22:
             mc = 2;
@@ -672,7 +704,11 @@ class tinyBLAS_Q0_AVX {
         case 0x14:
             mc = 1;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<1>(m0, m, n0, n);
+#else
             gemm<1, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x31:
             mc = 3;
@@ -708,6 +744,119 @@ class tinyBLAS_Q0_AVX {
         mnpack(m0, m, np, n);
     }
 
+#if defined(__AVX2__) && defined(__F16C__)
+// Templated functions for gemm of dimensions 4xN
+    template <int RN>
+    NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / 4;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * 4;
+            int64_t jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][4] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
+                // Convert delta values for four blocks to float values
+                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
+                __m256i avec0 = load(A + lda * (ii + 0) + l);
+                __m256i avec1 = load(A + lda * (ii + 1) + l);
+                __m256i avec2 = load(A + lda * (ii + 2) + l);
+                __m256i avec3 = load(A + lda * (ii + 3) + l);
+                for (int64_t j = 0; j < RN; ++j) {
+                        __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
+                        // Computation of product of delta values for four blocks and replicate it across 256 bit lane
+                        __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                        dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
+                        // Computation of dot product and multiplication with appropriate delta value products
+                        Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(avec0, avec0),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
+                                    Cv[j][0]);
+                        Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(avec1, avec1),
+                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
+                                    Cv[j][1]);
+                        Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(avec2, avec2),
+                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
+                                    Cv[j][2]);
+                        Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(avec3, avec3),
+                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
+                                    Cv[j][3]);
+                }
+            }
+
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < 4; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+
+    // Templated functions for gemm of dimensions Mx4
+    template <int RM>
+    NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / 4;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * 4;
+            __m256 Cv[4][RM] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
+                // Convert delta values for four blocks to float values
+                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
+                __m256i bvec0 = load(B + ldb * (jj + 0) + l);
+                __m256i bvec1 = load(B + ldb * (jj + 1) + l);
+                __m256i bvec2 = load(B + ldb * (jj + 2) + l);
+                __m256i bvec3 = load(B + ldb * (jj + 3) + l);
+                for (int64_t i = 0; i < RM; ++i) {
+                    __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
+                    // Computation of product of delta values for four blocks and replicate it across 256 bit lane
+                    __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                    dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
+                    // Computation of dot product and multiplication with appropriate delta value products
+                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                            load(A + lda * (ii + i) + l)),
+                                            _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
+                                    Cv[0][i]);
+                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                            load(A + lda * (ii + i) + l)),
+                                            _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
+                                    Cv[1][i]);
+                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                            load(A + lda * (ii + i) + l)),
+                                            _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
+                                    Cv[2][i]);
+                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                            load(A + lda * (ii + i) + l)),
+                                            _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
+                                    Cv[3][i]);
+                }
+            }
+            for (int64_t j = 0; j < 4; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+#endif
+
     template <int RM, int RN>
     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t ytiles = (m - m0) / RM;