From: Georgi Gerganov
Date: Sun, 23 Apr 2023 13:38:00 +0000 (+0300)
Subject: ggml : sync llama.cpp (AVX improvements)
X-Git-Tag: upstream/0.0.1642~1525
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=48913c8deb71417bc9e17aefe78ec89fe063f718;p=pkg%2Fggml%2Fsources%2Fggml

ggml : sync llama.cpp (AVX improvements)
---

diff --git a/src/ggml.c b/src/ggml.c
index 281b2028..3ee2d081 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -509,14 +509,25 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
     const __m256i ax = _mm256_sign_epi8(x, x);
     // Sign the values of the y vectors
     const __m256i sy = _mm256_sign_epi8(y, x);
+#if __AVXVNNI__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
     // Perform multiplication and create 16-bit values
     const __m256i dot = _mm256_maddubs_epi16(ax, sy);
     return sum_i16_pairs_float(dot);
+#endif
 }
 
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);  // 0000_0000_abcd_0000
+    bytes = _mm256_or_si256(bytes, bytes_srli_4);              // 0000_abcd_abcd_efgh
+    return _mm256_cvtepi16_epi8(bytes);                        // abcd_efgh
+#else
     const __m256i lowByte = _mm256_set1_epi16( 0xFF );
     __m256i high = _mm256_andnot_si256( lowByte, bytes );
     __m256i low = _mm256_and_si256( lowByte, bytes );
@@ -527,6 +538,7 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r0 = _mm256_castsi256_si128( bytes );
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
+#endif
 }
 #else
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
@@ -2935,6 +2947,7 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
+    float summs = 0.0f;
 
     // Main loop
     for (int i = 0; i < nb; i++) {
@@ -2942,9 +2955,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
         const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
         const __m256 dx = _mm256_set_m128(d1, d0);
 
-        const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
-        const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
-        const __m256 mx = _mm256_set_m128(m1, m0);
+        summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
+               + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
 
         const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
         const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
@@ -2953,16 +2965,12 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
         const __m256 dy = _mm256_broadcast_ss(&y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
-        const __m256 syf = sum_i16_pairs_float(syi);
-
         const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
-        const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
-        acc = _mm256_fmadd_ps(sxy, dy, acc);
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
     }
 
-    *s = hsum_float_8(acc);
+    *s = hsum_float_8(acc) + summs;
 #else
     // scalar
     float sumf = 0.0;
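
--
Implementation notes (editorial commentary, not part of the patch):

mul_sum_i8_pairs_float: on CPUs with AVX-VNNI, the maddubs + madd pair
collapses into a single VPDPBUSD. Below is a scalar model of what
_mm256_dpbusd_epi32(zero, ax, sy) computes per 32-bit lane, assuming ax holds
unsigned bytes (|x|) and sy holds the sign-adjusted y bytes. Unlike
_mm256_maddubs_epi16, which saturates its intermediate pair sums to int16,
VPDPBUSD accumulates the four byte products directly in int32.

    #include <stdint.h>

    // Per-lane model of VPDPBUSD with a zero accumulator: four u8*s8
    // products are summed into one int32, with no intermediate saturation.
    static int32_t dpbusd_lane(const uint8_t a[4], const int8_t b[4]) {
        int32_t sum = 0;
        for (int k = 0; k < 4; ++k) {
            sum += (int32_t)a[k] * (int32_t)b[k];
        }
        return sum;
    }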
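
packNibbles: the AVX512 path replaces the and/or/shift/extract/pack sequence
with one shift, one OR, and VPMOVWB (_mm256_cvtepi16_epi8), which truncates
each 16-bit lane to its low byte. A scalar model of one lane:

    #include <stdint.h>

    // Each 16-bit lane holds 0000_abcd_0000_efgh. OR-ing in (lane >> 4)
    // gives 0000_abcd_abcd_efgh; truncating to 8 bits keeps abcd_efgh,
    // which is what VPMOVWB does lane by lane.
    static uint8_t pack_nibbles_lane(uint16_t lane) {
        lane |= (uint16_t)(lane >> 4);
        return (uint8_t)lane;
    }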
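
ggml_vec_dot_q4_3_q8_0: the min term is hoisted out of the vector loop. A
q4_3 weight dequantizes to d*qx + m, so for one 16-value half-block

    sum_j (d*qx[j] + m) * (dy*qy[j])
        = d*dy * sum_j qx[j]*qy[j]  +  m * (dy * sum_j qy[j])

and the factor dy * sum(qy) in the second term is exactly the per-half sum
that q8_0 quantization already stores in y[i].s0 / y[i].s1. Recomputing
sum(qy) with maddubs in the inner loop (the removed syi/syf code) is
therefore redundant; the min contribution accumulates in the scalar summs and
is added once at the end. A scalar sketch of the identity (names are
illustrative, not the real ggml types):

    #include <stdint.h>

    // d, m:   q4_3 half-block scale and min; qx: its 16 unsigned nibbles.
    // dy, qy: q8_0 scale and its 16 signed bytes.
    // s = dy * sum(qy), precomputed at quantization time (y[i].s0 / y[i].s1
    // in the patch).
    static float half_block_dot(float d, float m, const uint8_t qx[16],
                                float dy, const int8_t qy[16], float s) {
        int32_t idot = 0;
        for (int j = 0; j < 16; ++j) {
            idot += (int32_t)qx[j] * (int32_t)qy[j];
        }
        // equals sum_j (d*qx[j] + m) * (dy*qy[j])
        return d*dy*(float)idot + m*s;
    }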