]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llamafile : extend sgemm.cpp support for Q5_0 models (#10010)
authorSrihari-mcw <redacted>
Fri, 25 Oct 2024 07:27:41 +0000 (12:57 +0530)
committerGitHub <redacted>
Fri, 25 Oct 2024 07:27:41 +0000 (10:27 +0300)
ggml/src/llamafile/sgemm.cpp

index 0193a463aefec94f0425e85ee3bc6d79c0e10d88..9eead3f61e090dd2f1c0e4840f94d0279165d68e 100644 (file)
@@ -942,6 +942,36 @@ class tinyBLAS_Q0_AVX {
         return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
     }
 
+    inline __m256i load(const block_q5_0 *b) {
+        return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
+    }
+
+    inline __m128i load0(const block_q5_0* b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        uint32_t x32;
+        memcpy(&x32, b->qh, sizeof(uint32_t));
+        __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
+        __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
+                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
+                                                                      _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
+        bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
+        return _mm_or_si128(qxl, bytesl);
+    }
+
+    inline __m128i load1(const block_q5_0* b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        uint32_t x32;
+        memcpy(&x32, b->qh, sizeof(uint32_t));
+        __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
+        __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
+                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
+                                                                      _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
+        bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
+        return _mm_or_si128(qxh, bytesh);
+    }
+
     inline __m256i load(const block_iq4_nl *b) {
         return MM256_SET_M128I(load1(b), load0(b));
     }
@@ -973,6 +1003,17 @@ class tinyBLAS_Q0_AVX {
                                                         _mm_srli_epi16(x, 4), 1));
     }
 
+    static inline __m256i bittobyte(const uint8_t *p) {
+        uint32_t x32;
+        memcpy(&x32, p, sizeof(uint32_t));
+        __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
+                                          _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                          _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
+                                                                              _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
+                                                                                                0x0101010101010101, 0x0000000000000000))));
+        return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
+    }
+
     const TA *const A;
     const TB *const B;
     TC *const C;
@@ -1182,6 +1223,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
 #endif
     }
 
+    case GGML_TYPE_Q5_0: {
+        if (Btype != GGML_TYPE_Q8_0)
+            return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
+            k, (const block_q5_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
     case GGML_TYPE_IQ4_NL: {
         if (Btype != GGML_TYPE_Q8_0)
             return false;