]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
AVX2 optimization for vec_dot_q4_2_q8_0 (#1068)
authorStephan Walter <redacted>
Thu, 20 Apr 2023 06:45:41 +0000 (06:45 +0000)
committerGitHub <redacted>
Thu, 20 Apr 2023 06:45:41 +0000 (08:45 +0200)
ggml.c

diff --git a/ggml.c b/ggml.c
index 9a3430859f7e1b3f90df0c751da52afbd0cd14e0..35b15cc2ee7805bed7875bfdaf90b0ad2f1fc86f 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -467,12 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-// AVX routines provided by GH user Const-me
-// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
+#if __AVX__ || __AVX2__ || __AVX512F__
+// Unpack 16 4-bit fields into 16 bytes
+// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
+static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
+{
+    // Load 8 bytes from memory
+    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
+
+    // Expand bytes into uint16_t values
+    __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m128i lowMask = _mm_set1_epi8( 0xF );
+    __m128i high = _mm_andnot_si128( lowMask, bytes );
+    __m128i low = _mm_and_si128( lowMask, bytes );
+    high = _mm_slli_epi16( high, 4 );
+    bytes = _mm_or_si128( low, high );
+    return bytes;
+}
+
 #if __AVX2__ || __AVX512F__
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i bytesFromNibbles( const uint8_t* rsi )
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     // Load 16 bytes from memory
     __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -503,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
 }
-#elif __AVX__
-static inline __m128i bytesFromNibbles( const uint8_t* rsi )
-{
-    // Load 8 bytes from memory
-    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
-
-    // Expand bytes into uint16_t values
-    __m128i bytes = _mm_cvtepu8_epi16( tmp );
-
-    // Unpack values into individual bytes
-    const __m128i lowMask = _mm_set1_epi8( 0xF );
-    __m128i high = _mm_andnot_si128( lowMask, bytes );
-    __m128i low = _mm_and_si128( lowMask, bytes );
-    high = _mm_slli_epi16( high, 4 );
-    bytes = _mm_or_si128( low, high );
-    return bytes;
-}
-
+#else
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -537,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
     return _mm_packus_epi16( bytes1, bytes2);
 }
 #endif
+#endif // __AVX__ || __AVX2__ || __AVX512F__
 
 #if __ARM_NEON
 
@@ -1395,7 +1397,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         for (int l = 0; l < QK4_0; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
-            __m256i vx8 = bytesFromNibbles(pp+l/2);
+            __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
 
             // Subtract 8 from the integers
             vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1513,7 +1515,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         for (int l = 0; l < QK4_1; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
-            __m256i vx8 = bytesFromNibbles(pp+l/2);
+            __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
 
             // Convert to 16-bit int
             const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -2356,7 +2358,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         /* Compute combined scale for the block */
         const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
 
-        __m256i bx = bytesFromNibbles(x[i].qs);
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
 
         // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
         const __m256i off = _mm256_set1_epi8( 8 );
@@ -2402,7 +2404,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         __m128i i32[2];
         for (int j = 0; j < 2; ++j) {
             // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
-            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
             __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
 
             // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
@@ -2567,7 +2569,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        const __m256i bx = bytesFromNibbles( x[i].qs );
+        const __m256i bx = bytes_from_nibbles_32(x[i].qs);
         const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
 
         // Get absolute values of x vectors
@@ -2721,6 +2723,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
     }
 
     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
+        const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
+        const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
+
+        __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
+        __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
+        __m256i bx = _mm256_set_m128i(bx1, bx0);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8(8);
+        bx = _mm256_sub_epi8(bx, off);
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert to vectore of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps(xy_q);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps(acc, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+
+    sumf = _mm_cvtss_f32(res);
 #else
     // scalar
     for (int i = 0; i < nb; i++) {