]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Adding SSE instructions to ggml_vec_dot_q4_0_q8_0 (#1413)
author3ooabkhxtn <redacted>
Sat, 13 May 2023 08:43:33 +0000 (10:43 +0200)
committerGitHub <redacted>
Sat, 13 May 2023 08:43:33 +0000 (08:43 +0000)
ggml.c

diff --git a/ggml.c b/ggml.c
index 096ccacfb7e083533d36f10607533839e61b52d5..4eccd417a15b72d83f5439f4e1b38ad89efcc77c 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -472,7 +472,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-#if __AVX__ || __AVX2__ || __AVX512F__
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
     // Get absolute values of x vectors
@@ -485,6 +485,7 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
     return _mm_madd_epi16(ones, dot);
 }
 
+#if __AVX__ || __AVX2__ || __AVX512F__
 // horizontally add 8 floats
 static inline float hsum_float_8(const __m256 x) {
     __m128 res = _mm256_extractf128_ps(x, 1);
@@ -596,7 +597,19 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
     return _mm_packus_epi16( bytes1, bytes2);
 }
 #endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 =_mm_hadd_ps(a, b);
+    __m128 res_1 =_mm_hadd_ps(c, d);
+    __m128 res =_mm_hadd_ps(res_0, res_1);
+    res =_mm_hadd_ps(res, res);
+    res =_mm_hadd_ps(res, res);
+
+    return _mm_cvtss_f32(res);
+}
 #endif // __AVX__ || __AVX2__ || __AVX512F__
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
 #if __ARM_NEON
 
@@ -2129,6 +2142,126 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     }
 
     *s = hsum_float_8(acc);
+#elif defined(__SSSE3__)
+    // set constants
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    const __m128i off = _mm_set1_epi8(8);
+
+    // Initialize accumulator with zeros
+    __m128 acc_0 = _mm_setzero_ps();
+    __m128 acc_1 = _mm_setzero_ps();
+    __m128 acc_2 = _mm_setzero_ps();
+    __m128 acc_3 = _mm_setzero_ps();
+
+    // First round without accumulation
+    {
+        _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 0 and 1
+        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 2 and 3
+        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        acc_0 = _mm_mul_ps( d_0_1, p0 );
+        acc_1 = _mm_mul_ps( d_0_1, p1 );
+        acc_2 = _mm_mul_ps( d_2_3, p2 );
+        acc_3 = _mm_mul_ps( d_2_3, p3 );
+    }
+
+    // Main loop
+    for (int i = 2; i < nb; i+=2) {
+        _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 0 and 1
+        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 2 and 3
+        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+        __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+        __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+        __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+        // Acummulate
+        acc_0 = _mm_add_ps(p0_d, acc_0);
+        acc_1 = _mm_add_ps(p1_d, acc_1);
+        acc_2 = _mm_add_ps(p2_d, acc_2);
+        acc_3 = _mm_add_ps(p3_d, acc_3);
+    }
+
+    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
 #else
     // scalar
     float sumf = 0.0;