git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
add avx2 for dot_q8_0_q8_0, 2x faster than scalar (#1211)
author: Yann Follet <redacted>
Fri, 28 Apr 2023 11:59:48 +0000 (19:59 +0800)
committer: GitHub <redacted>
Fri, 28 Apr 2023 11:59:48 +0000 (11:59 +0000)
ggml.c

diff --git a/ggml.c b/ggml.c
index 3422a94481eca2f2494c21054b356fa2cc7ebc0e..1fbf2955d67308b8c2f25ab72aef6cf57b8c637a 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -3626,6 +3626,24 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        // Multiply q with scale and accumulate
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    *s = hsum_float_8(acc);
 #else
     // scalar
     float sumf = 0.0;