]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
~7% faster Q5_1 AVX2 code (#1477)
authorIlya Kurdyukov <redacted>
Tue, 16 May 2023 18:36:47 +0000 (01:36 +0700)
committerGitHub <redacted>
Tue, 16 May 2023 18:36:47 +0000 (18:36 +0000)
ggml.c

diff --git a/ggml.c b/ggml.c
index 4311ce7cf9dbe67cb4a8f54c276b231ad7354e6b..dbef99312a56d6ec4cf5f1c7324b6cd1bee6f78e 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
-// multiply int8_t, add results pairwise twice and return as float vector
-static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-    // Get absolute values of x vectors
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    // Sign the values of the y vectors
-    const __m256i sy = _mm256_sign_epi8(y, x);
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
 #if __AVXVNNI__
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
@@ -560,6 +555,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
 #endif
 }
 
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -619,6 +629,17 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    const __m128i axl = _mm256_castsi256_si128(ax);
+    const __m128i axh = _mm256_extractf128_si256(ax, 1);
+    const __m128i syl = _mm256_castsi256_si128(sy);
+    const __m128i syh = _mm256_extractf128_si256(sy, 1);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
 // multiply int8_t, add results pairwise twice and return as float vector
 static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
     const __m128i xl = _mm256_castsi256_si128(x);
@@ -2434,7 +2455,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
         const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
 
-        const __m256 xy = mul_sum_i8_pairs_float(bx, by);
+        const __m256 xy = mul_sum_us8_pairs_float(bx, by);
 
         // Accumulate d0*d1*x*y
 #if defined(__AVX2__)
@@ -2906,7 +2927,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const __m256 dy = _mm256_broadcast_ss(&y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
 
         acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
     }
@@ -2940,7 +2961,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const __m256 dy = _mm256_broadcast_ss(&y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
 
         acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
     }