]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp - new quantization formats Q4 + Q8
authorGeorgi Gerganov <redacted>
Sat, 20 May 2023 11:56:14 +0000 (14:56 +0300)
committerGeorgi Gerganov <redacted>
Sat, 20 May 2023 11:56:14 +0000 (14:56 +0300)
include/ggml/ggml.h
src/ggml.c

index 3269218bd261b6f606c23a4fcfffae13ce64b807..51a616c501bb3eb46bf2b20c727b0c0a0b7a16dc 100644 (file)
 #define GGML_FILE_MAGIC   0x67676d6c // "ggml"
 #define GGML_FILE_VERSION 1
 
-#define GGML_QNT_VERSION        1    // bump this on quantization format changes
+#define GGML_QNT_VERSION        2    // bump this on quantization format changes
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
 #define GGML_MAX_DIMS          4
index 5e6a9a0d462995a193006b4fad17474018717170..6b48852d8093ac419301a5173ba6d5962d4ee675 100644 (file)
@@ -512,7 +512,7 @@ static inline int hsum_i32_4(const __m128i a) {
     return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
 }
 
-#if __AVX2__ || __AVX512F__
+#if defined(__AVX2__) || defined(__AVX512F__)
 // spread 32 bits to 32 bytes { 0x00, 0xFF }
 static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     uint32_t x32;
@@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
-// multiply int8_t, add results pairwise twice and return as float vector
-static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-    // Get absolute values of x vectors
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    // Sign the values of the y vectors
-    const __m256i sy = _mm256_sign_epi8(y, x);
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
 #if __AVXVNNI__
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
@@ -560,6 +555,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
 #endif
 }
 
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -619,6 +629,17 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    const __m128i axl = _mm256_castsi256_si128(ax);
+    const __m128i axh = _mm256_extractf128_si256(ax, 1);
+    const __m128i syl = _mm256_castsi256_si128(sy);
+    const __m128i syh = _mm256_extractf128_si256(sy, 1);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
 // multiply int8_t, add results pairwise twice and return as float vector
 static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
     const __m128i xl = _mm256_castsi256_si128(x);
@@ -667,7 +688,7 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
 #endif // __AVX__ || __AVX2__ || __AVX512F__
 #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
-#if __ARM_NEON
+#if defined(__ARM_NEON)
 
 #if !defined(__aarch64__)
 
@@ -748,18 +769,18 @@ int32x4_t vcvtnq_s32_f32(float32x4_t v) {
 
 #define QK4_0 32
 typedef struct {
-    float   d;          // delta
+    ggml_fp16_t d;          // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
 #define QK4_1 32
 typedef struct {
-    float   d;          // delta
-    float   m;          // min
+    ggml_fp16_t d;          // delta
+    ggml_fp16_t m;          // min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
 #define QK5_0 32
 typedef struct {
@@ -780,16 +801,16 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
 
 #define QK8_0 32
 typedef struct {
-    float   d;          // delta
-    int8_t  qs[QK8_0];  // quants
+    ggml_fp16_t d;         // delta
+    int8_t  qs[QK8_0];     // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
 #define QK8_1 32
 typedef struct {
-    float   d;         // delta
-    float   s;         // d * sum(qs[i])
-    int8_t  qs[QK8_1]; // quants
+    float d;               // delta
+    float s;               // d * sum(qs[i])
+    int8_t  qs[QK8_1];     // quants
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
 
@@ -816,7 +837,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
         const float d  = max / -8;
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
 
         for (int j = 0; j < qk/2; ++j) {
             const float x0 = x[i*qk + 0    + j]*id;
@@ -856,8 +877,8 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
         const float d  = (max - min) / ((1 << 4) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
-        y[i].m = min;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
 
         for (int j = 0; j < qk/2; ++j) {
             const float x0 = (x[i*qk + 0    + j] - min)*id;
@@ -988,7 +1009,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
         const float d = amax / ((1 << 7) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
 
         for (int j = 0; j < QK8_0; ++j) {
             const float x0 = x[i*QK8_0 + j]*id;
@@ -1023,7 +1044,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         const float d = amax / ((1 << 7) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
 
         for (int j = 0; j < 8; j++) {
             const float32x4_t v  = vmulq_n_f32(srcv[j], id);
@@ -1058,7 +1079,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         // Quantize these floats
         const float d = maxScalar / 127.f;
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
         const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
         const __m256 mul = _mm256_set1_ps( id );
 
@@ -1157,7 +1178,7 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r
             sum += y[i].qs[QK8_1/2 + j];
         }
 
-        y[i].s = d * sum;
+        y[i].s = sum*d;
     }
 }
 
@@ -1309,7 +1330,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
     const int nb = k / qk;
 
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
+        const float d = GGML_FP16_TO_FP32(x[i].d);
 
         for (int j = 0; j < qk/2; ++j) {
             const int x0 = (x[i].qs[j] & 0x0F) - 8;
@@ -1329,8 +1350,8 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
     const int nb = k / qk;
 
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
-        const float m = x[i].m;
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
 
         for (int j = 0; j < qk/2; ++j) {
             const int x0 = (x[i].qs[j] & 0x0F);
@@ -1405,7 +1426,7 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
     const block_q8_0 * restrict x = vx;
 
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
+        const float d = GGML_FP16_TO_FP32(x[i].d);
 
         for (int j = 0; j < qk; ++j) {
             y[i*qk + j] = x[i].qs[j]*d;
@@ -1669,8 +1690,9 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
 static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
     float tmp[8];
 
-    for (int i = 0; i < 8; i++)
+    for (int i = 0; i < 8; i++) {
         tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
 
     return _mm256_loadu_ps(tmp);
 }
@@ -2090,8 +2112,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
-        const uint8x16_t m4b   = vdupq_n_u8(0x0F);
-        const int8x16_t  s8b   = vdupq_n_s8(0x8);
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2119,8 +2141,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
@@ -2137,8 +2159,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #endif
     }
 
@@ -2150,7 +2172,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; ++i) {
         /* Compute combined scale for the block */
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
 
@@ -2174,7 +2196,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; ++i) {
         // Compute combined scale for the block
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
 
         const __m128i lowMask = _mm_set1_epi8(0xF);
         const __m128i off = _mm_set1_epi8(8);
@@ -2216,7 +2238,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
 
         const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
 
@@ -2234,7 +2256,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) );
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
 
         const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
 
@@ -2267,7 +2289,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
 
         const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
 
@@ -2285,7 +2307,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
 
         const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
 
@@ -2333,7 +2355,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
             sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (x[i].d*y[i].d)*sumi;
+        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
     }
 
     *s = sumf;
@@ -2363,7 +2385,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const block_q8_1 * restrict y0 = &y[i + 0];
         const block_q8_1 * restrict y1 = &y[i + 1];
 
-        summs += x0->m * y0->s + x1->m * y1->s;
+        summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
 
         const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
@@ -2387,8 +2409,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
@@ -2405,8 +2427,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #endif
     }
 
@@ -2419,13 +2441,13 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
 
     // Main loop
     for (int i = 0; i < nb; ++i) {
-        const float * d0 = &x[i].d;
-        const float * d1 = &y[i].d;
+        const float d0 = GGML_FP16_TO_FP32(x[i].d);
+        const float d1 = y[i].d;
 
-        summs += x[i].m * y[i].s;
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
 
-        const __m256 d0v = _mm256_broadcast_ss( d0 );
-        const __m256 d1v = _mm256_broadcast_ss( d1 );
+        const __m256 d0v = _mm256_set1_ps( d0 );
+        const __m256 d1v = _mm256_set1_ps( d1 );
 
         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
@@ -2434,7 +2456,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
         const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
 
-        const __m256 xy = mul_sum_i8_pairs_float(bx, by);
+        const __m256 xy = mul_sum_us8_pairs_float(bx, by);
 
         // Accumulate d0*d1*x*y
 #if defined(__AVX2__)
@@ -2459,7 +2481,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
             sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (x[i].d*y[i].d)*sumi + x[i].m*y[i].s;
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
     }
 
     *s = sumf;
@@ -2535,16 +2557,13 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        const float x0d = GGML_FP16_TO_FP32(x0->d);
-        const float x1d = GGML_FP16_TO_FP32(x1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
@@ -2561,8 +2580,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #endif
     }
 
@@ -2637,7 +2656,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; i++) {
         /* Compute combined scale for the block */
-        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
         __m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -2661,7 +2680,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; i++) {
         /* Compute combined scale for the block */
-        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
         const __m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -2704,7 +2723,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
             sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi;
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
     }
 
     *s = sumf;
@@ -2786,16 +2805,13 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        const float x0d = GGML_FP16_TO_FP32(x0->d);
-        const float x1d = GGML_FP16_TO_FP32(x1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
@@ -2812,8 +2828,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #endif
     }
 
@@ -2873,15 +2889,14 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
         const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
 
-        const float x0d = GGML_FP16_TO_FP32(x0->d);
-
         // dot product
-        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
-                        wasm_i32x4_add(
-                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
-                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
-                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
-                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
+        sumv = wasm_f32x4_add(sumv,
+                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
+                wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                               wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                               wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d));
     }
 
     *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
@@ -2903,10 +2918,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
         bx = _mm256_or_si256(bx, bxhi);
 
-        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256 dy = _mm256_set1_ps(y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
 
         acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
     }
@@ -2937,10 +2952,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         bxh = _mm_or_si128(bxh, bxhih);
         bx = _mm256_set_m128i(bxh, bxl);
 
-        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256 dy = _mm256_set1_ps(y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
 
         acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
     }
@@ -3007,11 +3022,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
-                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
+                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
 
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
-                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
+                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 
 #else
         const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
@@ -3029,8 +3044,8 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
         const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #endif
     }
 
@@ -3042,7 +3057,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; ++i) {
         // Compute combined scale for the block
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
         __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
         __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
@@ -3068,7 +3083,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
             sumi += x[i].qs[j]*y[i].qs[j];
         }
 
-        sumf += (x[i].d*y[i].d)*sumi;
+        sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
     }
 
     *s = sumf;
@@ -6214,7 +6229,6 @@ struct ggml_tensor * ggml_alibi(
     ((int32_t *) b->data)[1] = n_head;
     GGML_ASSERT(sizeof(float) == sizeof(int32_t));
     (((float *) b->data)[2]) = bias_max;
-    
 
     ggml_scratch_load(ctx);
 
@@ -6246,7 +6260,6 @@ struct ggml_tensor * ggml_clamp(
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
     ((float *) b->data)[0] = min;
     ((float *) b->data)[1] = max;
-    
 
     result->op   = GGML_OP_CLAMP;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -10538,34 +10551,29 @@ static void ggml_compute_forward_diag_mask_f32(
     assert(src1->type == GGML_TYPE_I32);
     assert(ggml_nelements(src1) == 2);
 
+    const int ith = params->ith;
+    const int nth = params->nth;
+
     const int  n_past  =       ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
 
-    if (params->type == GGML_TASK_INIT) {
-        // TODO: this hack is not good, need a better way to handle this
-        if (!inplace) {
-            // use the init task to copy src -> dst
-            struct ggml_compute_params params_cpy = *params;
-
-            params_cpy.ith  = 0;
-            params_cpy.nth  = 1;
-            params_cpy.type = GGML_TASK_COMPUTE;
-
-            ggml_compute_forward_dup_same_cont(&params_cpy, src0, dst);
-        }
+    assert(n_past >= 0);
 
-        return;
+    if (!inplace && (params->type == GGML_TASK_INIT)) {
+        // memcpy needs to be synchronized across threads to avoid race conditions.
+        // => do it in INIT phase
+        GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+        GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+        memcpy(
+            ((char *)  dst->data),
+            ((char *) src0->data),
+            ggml_nbytes(dst));
     }
 
-    if (params->type == GGML_TASK_FINALIZE) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    assert(n_past >= 0);
-
     // TODO: handle transposed/permuted matrices
 
     const int n  = ggml_nrows(src0);
@@ -10725,9 +10733,9 @@ static void ggml_compute_forward_alibi_f32(
         return;
     }
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_head = ((int32_t *) src1->data)[1];
-    const float max_bias = ((float *) src1->data)[2];
+    const int   n_past   = ((int32_t *) src1->data)[0];
+    const int   n_head   = ((int32_t *) src1->data)[1];
+    const float max_bias = ((float *)   src1->data)[2];
 
     assert(n_past >= 0);
 
@@ -10790,9 +10798,9 @@ static void ggml_compute_forward_alibi_f16(
         return;
     }
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_head = ((int32_t *) src1->data)[1];
-    const float max_bias = ((float *) src1->data)[2];
+    const int   n_past   = ((int32_t *) src1->data)[0];
+    const int   n_head   = ((int32_t *) src1->data)[1];
+    const float max_bias = ((float *)   src1->data)[2];
 
     assert(n_past >= 0);