const __m256i ax = _mm256_sign_epi8(x, x);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(y, x);
+#if __AVXVNNI__
+ const __m256i zero = _mm256_setzero_si256();
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+ return _mm256_cvtepi32_ps(summed_pairs);
+#else
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
return sum_i16_pairs_float(dot);
+#endif
}
static inline __m128i packNibbles( __m256i bytes )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
+#else
const __m256i lowByte = _mm256_set1_epi16( 0xFF );
__m256i high = _mm256_andnot_si256( lowByte, bytes );
__m256i low = _mm256_and_si256( lowByte, bytes );
__m128i r0 = _mm256_castsi256_si128( bytes );
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
return _mm_packus_epi16( r0, r1 );
+#endif
}
#else
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
+ float summs = 0.0f;
// Main loop
for (int i = 0; i < nb; i++) {
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
const __m256 dx = _mm256_set_m128(d1, d0);
- const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
- const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
- const __m256 mx = _mm256_set_m128(m1, m0);
+ summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
+ + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
- const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
- const __m256 syf = sum_i16_pairs_float(syi);
-
const __m256 q = mul_sum_i8_pairs_float(bx, by);
- const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
- acc = _mm256_fmadd_ps(sxy, dy, acc);
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
}
- *s = hsum_float_8(acc);
+ *s = hsum_float_8(acc) + summs;
#else
// scalar
float sumf = 0.0;