]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add initial AVX512 support for dot product on Linux (#320)
authorCasey Primozic <redacted>
Tue, 21 Mar 2023 14:35:42 +0000 (07:35 -0700)
committerGitHub <redacted>
Tue, 21 Mar 2023 14:35:42 +0000 (15:35 +0100)
 * Update Makefile to detect AVX512 support and add compiler flags if it's available
 * Based on existing AVX2 implementation, dot product on one 32-value block of 4-bit quantized ints at a time
 * Perform 8 bit -> 16 bit sign extension and multiply+add on 32 values at time instead of 16
 * Use built-in AVX512 horizontal reduce add to get sum at the end
 * Manual unrolling on inner dot product loop to reduce loop counter overhead

Makefile
ggml.c

index 44fb298403eca6b1893eebbb436b541d42524cda..ec2eb75696d9546354aba902fc4a3288ef4b29f2 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -95,6 +95,38 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
                ifneq (,$(findstring sse3,$(SSE3_M)))
                        CFLAGS += -msse3
                endif
+               AVX512F_M := $(shell grep "avx512f " /proc/cpuinfo)
+               ifneq (,$(findstring avx512f,$(AVX512F_M)))
+                       CFLAGS += -mavx512f
+               endif
+               AVX512BW_M := $(shell grep "avx512bw " /proc/cpuinfo)
+               ifneq (,$(findstring avx512bw,$(AVX512BW_M)))
+                       CFLAGS += -mavx512bw
+               endif
+               AVX512DQ_M := $(shell grep "avx512dq " /proc/cpuinfo)
+               ifneq (,$(findstring avx512dq,$(AVX512DQ_M)))
+                       CFLAGS += -mavx512dq
+               endif
+               AVX512VL_M := $(shell grep "avx512vl " /proc/cpuinfo)
+               ifneq (,$(findstring avx512vl,$(AVX512VL_M)))
+                       CFLAGS += -mavx512vl
+               endif
+               AVX512CD_M := $(shell grep "avx512cd " /proc/cpuinfo)
+               ifneq (,$(findstring avx512cd,$(AVX512CD_M)))
+                       CFLAGS += -mavx512cd
+               endif
+               AVX512ER_M := $(shell grep "avx512er " /proc/cpuinfo)
+               ifneq (,$(findstring avx512er,$(AVX512ER_M)))
+                       CFLAGS += -mavx512er
+               endif
+               AVX512IFMA_M := $(shell grep "avx512ifma " /proc/cpuinfo)
+               ifneq (,$(findstring avx512ifma,$(AVX512IFMA_M)))
+                       CFLAGS += -mavx512ifma
+               endif
+               AVX512PF_M := $(shell grep "avx512pf " /proc/cpuinfo)
+               ifneq (,$(findstring avx512pf,$(AVX512PF_M)))
+                       CFLAGS += -mavx512pf
+               endif
        else ifeq ($(UNAME_S),Haiku)
                AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
                ifneq (,$(findstring avx,$(AVX1_M)))
diff --git a/ggml.c b/ggml.c
index 4813f74c895c950cd5c75309418dbcaf537893fe..f85138f3853c8c7f300ff2009a18000202523c7e 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -361,7 +361,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 
 // AVX routines provided by GH user Const-me
 // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
-#if __AVX2__
+#if __AVX2__ || __AVX512F__
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
 static inline __m256i bytesFromNibbles( const uint8_t* rsi )
@@ -397,7 +397,6 @@ static inline __m128i packNibbles( __m256i bytes )
 }
 #endif
 
-
 // method 5
 // blocks of QK elements
 // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -1262,6 +1261,47 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
+#if __AVX512F__ && QK == 32
+static inline __m512 dot_q4_0_oneblock_avx512(
+    __m512 acc,
+    const uint8_t * pd0,
+    const uint8_t * pd1,
+    const uint8_t * pb0,
+    const uint8_t * pb1,
+    size_t bs,
+    int i
+) {
+    const float * d0_0 = (const float *) (pd0 + i*bs);
+    const float * d1_0 = (const float *) (pd1 + i*bs);
+
+    const uint8_t * restrict p0 = pb0 + (i+0)*bs;
+    const uint8_t * restrict p1 = pb1 + (i+0)*bs;
+
+    // Compute combined scale for the block
+    float scaleScalar = d0_0[0] * d1_0[0];
+    __m512 scale = _mm512_set1_ps( scaleScalar );
+
+    __m256i bx = bytesFromNibbles( p0 );
+    __m256i by = bytesFromNibbles( p1 );
+
+    // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+    const __m256i off = _mm256_set1_epi8( 8 );
+    bx = _mm256_sub_epi8( bx, off );
+    by = _mm256_sub_epi8( by, off );
+
+    // Sign-extend 16 signed bytes into int16_t
+    __m512i x32 = _mm512_cvtepi8_epi16( bx );
+    __m512i y32 = _mm512_cvtepi8_epi16( by );
+    // Compute products of int16_t integers, add pairwise
+    __m512i i64 = _mm512_madd_epi16( x32, y32 );
+
+    // Convert int32_t to float
+    __m512 p = _mm512_cvtepi32_ps( i64 );
+    // Apply the scale, and accumulate
+    return _mm512_fmadd_ps( scale, p, acc );
+}
+#endif
+
 inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;
 
@@ -1417,6 +1457,40 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 #else
 #error "not implemented for QK"
 #endif
+#elif defined(__AVX512F__)
+
+#if QK == 32
+    // Initialize accumulator with zeros
+    __m512 acc0 = _mm512_setzero_ps();
+    __m512 acc1 = _mm512_setzero_ps();
+
+    const int superblock_size = 8;
+    const int superblock_count = nb / superblock_size;
+    const int remainder = nb % superblock_size;
+
+    for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
+        int i = superblock_ix * superblock_size;
+
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 );
+    }
+
+    // Remainders
+    for (int i = superblock_count * superblock_size; i < nb; ++i) {
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i );
+    }
+
+    // Horizontal sum of all lanes of the accumulator
+    sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
+#else
+#error "not implemented for QK"
+#endif
 #elif defined(__AVX2__)
 #if QK == 32
     const size_t countBlocks = nb;
@@ -1928,7 +2002,7 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
     const size_t bs = 2*sizeof(float) + QK/2;
 
     const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
-    const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs +   sizeof(float)); 
+    const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs +   sizeof(float));
     const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
 
     for (int i = 0; i < nb; i++) {