]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
A better `packNibbles` and `mul_sum_i8_pairs_float` implementation using AVX512 ...
authorYishuo Wang <redacted>
Sun, 23 Apr 2023 07:57:05 +0000 (15:57 +0800)
committerGitHub <redacted>
Sun, 23 Apr 2023 07:57:05 +0000 (07:57 +0000)
ggml.c

diff --git a/ggml.c b/ggml.c
index 281b20283c16f858f53bf506fccae122ea0a2f20..3c45c5e9d74b2f0c92965f796d14440c251caa0a 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -509,14 +509,25 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
     const __m256i ax = _mm256_sign_epi8(x, x);
     // Sign the values of the y vectors
     const __m256i sy = _mm256_sign_epi8(y, x);
+#if __AVXVNNI__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
     // Perform multiplication and create 16-bit values
     const __m256i dot = _mm256_maddubs_epi16(ax, sy);
     return sum_i16_pairs_float(dot);
+#endif
 }
 
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000
+    bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh
+    return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh
+#else
     const __m256i lowByte = _mm256_set1_epi16( 0xFF );
     __m256i high = _mm256_andnot_si256( lowByte, bytes );
     __m256i low = _mm256_and_si256( lowByte, bytes );
@@ -527,6 +538,7 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r0 = _mm256_castsi256_si128( bytes );
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
+#endif
 }
 #else
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )