]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp (Q4_3 + CUDA)
authorGeorgi Gerganov <redacted>
Sat, 22 Apr 2023 09:36:42 +0000 (12:36 +0300)
committerGeorgi Gerganov <redacted>
Sat, 22 Apr 2023 09:52:43 +0000 (12:52 +0300)
scripts/sync-llama.sh [new file with mode: 0755]
src/ggml-cuda.cu [new file with mode: 0644]
src/ggml-cuda.h [new file with mode: 0644]
src/ggml.c

diff --git a/scripts/sync-llama.sh b/scripts/sync-llama.sh
new file mode 100755 (executable)
index 0000000..e5d8c4d
--- /dev/null
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+cp -rpv ../llama.cpp/ggml.c       src/ggml.c
+cp -rpv ../llama.cpp/ggml-cuda.cu src/ggml-cuda.cu
+cp -rpv ../llama.cpp/ggml-cuda.h  src/ggml-cuda.h
+cp -rpv ../llama.cpp/ggml.h       include/ggml/ggml.h
diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu
new file mode 100644 (file)
index 0000000..fa511c1
--- /dev/null
@@ -0,0 +1,228 @@
+#include <stdint.h>
+#include <stdio.h>
+#include <cuda_fp16.h>
+#include <atomic>
+#include "ggml-cuda.h"
+
+typedef uint16_t ggml_fp16_t;
+static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+#define QK4_0 32
+typedef struct {
+    float   d;              // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+    float   d;              // delta
+    float   m;              // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK4_2 16
+typedef struct {
+    __half  d;              // delta
+    uint8_t qs[QK4_2 / 2];  // nibbles / quants
+} block_q4_2;
+static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
+#define QK4_3 16
+typedef struct {
+    __half  d;              // delta
+    __half  m;              // min
+    uint8_t qs[QK4_3 / 2];  // nibbles / quants
+} block_q4_3;
+static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
+
+static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_0; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        y[i*QK4_0 + l + 0] = v0;
+        y[i*QK4_0 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+    const float m = x[i].m;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_1; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = vi0*d + m;
+        const float v1 = vi1*d + m;
+
+        y[i*QK4_1 + l + 0] = v0;
+        y[i*QK4_1 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
+    const block_q4_2 * x = (const block_q4_2 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_2; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        y[i*QK4_2 + l + 0] = v0;
+        y[i*QK4_2 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
+    const block_q4_3 * x = (const block_q4_3 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+    const float m = x[i].m;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_3; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = vi0*d + m;
+        const float v1 = vi1*d + m;
+
+        y[i*QK4_3 + l + 0] = v0;
+        y[i*QK4_3 + l + 1] = v1;
+    }
+}
+
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_0;
+    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_1;
+    dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_2;
+    dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_3;
+    dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+// buffer pool for cuda
+#define MAX_CUDA_BUFFERS 16
+
+struct scoped_spin_lock {
+    std::atomic_flag& lock;
+    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
+        while (lock.test_and_set(std::memory_order_acquire)) {
+            ; // spin
+        }
+    }
+    ~scoped_spin_lock() {
+        lock.clear(std::memory_order_release);
+    }
+    scoped_spin_lock(const scoped_spin_lock&) = delete;
+    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
+};
+
+struct cuda_buffer {
+    void * ptr = nullptr;
+    size_t size = 0;
+};
+
+static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.size >= size && b.ptr != nullptr) {
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
+        }
+    }
+    void * ptr;
+    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
+    *actual_size = size;
+    return ptr;
+}
+
+void ggml_cuda_pool_free(void * ptr, size_t size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.ptr == nullptr) {
+            b.ptr = ptr;
+            b.size = size;
+            return;
+        }
+    }
+    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+    CUDA_CHECK(cudaFree(ptr));
+}
+
+cublasHandle_t g_cublasH = NULL;
+cudaStream_t g_cudaStream = NULL;
+
+void ggml_init_cublas(void) {
+    if (g_cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&g_cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
+
+        CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+    }
+}
diff --git a/src/ggml-cuda.h b/src/ggml-cuda.h
new file mode 100644 (file)
index 0000000..370bbc7
--- /dev/null
@@ -0,0 +1,41 @@
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#define CUDA_CHECK(err)                                                                 \
+    do {                                                                                \
+        cudaError_t err_ = (err);                                                       \
+        if (err_ != cudaSuccess) {                                                      \
+            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
+                cudaGetErrorString(err_));                                              \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                                               \
+    do {                                                                                \
+        cublasStatus_t err_ = (err);                                                    \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+extern cublasHandle_t g_cublasH;
+extern cudaStream_t   g_cudaStream;
+
+void   ggml_init_cublas(void);
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
+void   ggml_cuda_pool_free(void * ptr, size_t size);
+
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+
+#ifdef  __cplusplus
+}
+#endif
index 998602150fe559e7c72c2feb81aa8ca96b109e62..4277683e5b9a3d41661409d1ece37955ebfcff7e 100644 (file)
@@ -148,44 +148,7 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
 #elif defined(GGML_USE_CUBLAS)
-#include <cublas_v2.h>
-#include <cuda_runtime.h>
 #include "ggml-cuda.h"
-
-#define CUDA_CHECK(err)                                                        \
-    do {                                                                       \
-        cudaError_t err_ = (err);                                              \
-        if (err_ != cudaSuccess) {                                             \
-            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
-                cudaGetErrorString(err_));                                     \
-            exit(1);                                                           \
-        }                                                                      \
-    } while (0)
-
-#define CUBLAS_CHECK(err)                                                      \
-    do {                                                                       \
-        cublasStatus_t err_ = (err);                                           \
-        if (err_ != CUBLAS_STATUS_SUCCESS) {                                   \
-            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
-            exit(1);                                                           \
-        }                                                                      \
-    } while (0)
-
-static cublasHandle_t cublasH = NULL;
-static cudaStream_t cudaStream = NULL;
-static void init_cublas(void) {
-    if (cublasH == NULL) {
-        // create cublas handle, bind a stream
-        CUBLAS_CHECK(cublasCreate(&cublasH));
-
-        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
-
-        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
-
-        // configure logging to stdout
-        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
-    }
-}
 #endif
 
 #undef MIN
@@ -487,6 +450,32 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
     return bytes;
 }
 
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
 #if __AVX2__ || __AVX512F__
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
@@ -507,6 +496,24 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
     return bytes;
 }
 
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_float(dot);
+}
+
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -657,9 +664,11 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
+    float   s0;         // d * sum(qs[i]) low
+    float   s1;         // d * sum(qs[i]) high
     int8_t  qs[QK8_0];  // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 3*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
 
 // reference implementation for deterministic creation of model files
@@ -1233,9 +1242,9 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int
 
     block_q4_2 * restrict y = vy;
 
-    //quantize_row_q4_2_reference(x, y, k);
+    quantize_row_q4_2_reference(x, y, k);
     // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
-    quantize_row_q4_2_rmse(x, y, k);
+    //quantize_row_q4_2_rmse(x, y, k);
 }
 
 static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
@@ -1299,10 +1308,22 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 
         y[i].d = d;
 
-        for (int l = 0; l < QK8_0; ++l) {
-            const float v = x[i*QK8_0 + l]*id;
-            y[i].qs[l] = roundf(v);
+        int sum0 = 0;
+        int sum1 = 0;
+
+        for (int l = 0; l < QK8_0/2; ++l) {
+            const float v0 = x[i*QK8_0           + l]*id;
+            const float v1 = x[i*QK8_0 + QK8_0/2 + l]*id;
+
+            y[i].qs[          l] = roundf(v0);
+            y[i].qs[QK8_0/2 + l] = roundf(v1);
+
+            sum0 += y[i].qs[          l];
+            sum1 += y[i].qs[QK8_0/2 + l];
         }
+
+        y[i].s0 = d * sum0;
+        y[i].s1 = d * sum1;
     }
 }
 
@@ -1332,7 +1353,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         y[i].d = d;
 
-        for (int l = 0; l < 8; l++) {
+        int32x4_t accv0 = vdupq_n_s32(0);
+        int32x4_t accv1 = vdupq_n_s32(0);
+
+        // low half
+        for (int l = 0; l < 4; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);
 
@@ -1340,7 +1365,28 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv0 = vaddq_s32(accv0, vi);
         }
+
+        // high half
+        for (int l = 4; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv1 = vaddq_s32(accv1, vi);
+        }
+
+        const int32_t sum0 = vaddvq_s32(accv0);
+        const int32_t sum1 = vaddvq_s32(accv1);
+
+        y[i].s0 = d * sum0;
+        y[i].s1 = d * sum1;
     }
 #elif defined(__AVX2__) || defined(__AVX__)
     for (int i = 0; i < nb; i++) {
@@ -1388,6 +1434,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m256i i3 = _mm256_cvtps_epi32( v3 );
 
 #if defined(__AVX2__)
+        // Compute the sum of the quants and set y[i].s
+        //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+        y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
+        y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
+
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );     // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
         i2 = _mm256_packs_epi32( i2, i3 );     // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
@@ -1413,6 +1464,12 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m128i ni6 = _mm256_castsi256_si128( i3 );
         __m128i ni7 = _mm256_extractf128_si256( i3, 1);
 
+        // Compute the sum of the quants and set y[i].s
+        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+        y[i].s0 = d * hsum_i32_4(s0);
+        y[i].s1 = d * hsum_i32_4(s1);
+
         // Convert int32 to int16
         ni0 = _mm_packs_epi32( ni0, ni1 );
         ni2 = _mm_packs_epi32( ni2, ni3 );
@@ -2366,20 +2423,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     const block_q4_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
-    float sumf = 0.0;
-
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    float sum8 = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
+        sum8 += x0->d * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1);
+
         const uint8x16_t m4b   = vdupq_n_u8(0xf);
-        const int8x16_t  s8b   = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2390,12 +2448,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
-        // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2410,21 +2462,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
 
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
 
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2436,7 +2488,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2454,32 +2506,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
         __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        // Get absolute values of x vectors
-        const __m256i ax = _mm256_sign_epi8(bx, bx);
-
-        // Sign the values of the y vectors
-        const __m256i sy = _mm256_sign_epi8(by, bx);
-
-        // Perform multiplication and create 16-bit values
-        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-
-        const __m256i ones = _mm256_set1_epi16(1);
-        __m256i xy_q = _mm256_madd_epi16(ones, dot);
-
-        /* Convert to vectore of 8 int32_t to 8 floats */
-        __m256 q = _mm256_cvtepi32_ps( xy_q );
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
         /* Multiply q with scale and accumulate */
         acc = _mm256_fmadd_ps( d, q, acc );
     }
 
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc);
 #elif defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2518,15 +2551,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
     }
 
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc);
 #else
     // scalar
+    float sumf = 0.0;
     for (int i = 0; i < nb; i++) {
         const float d0 = x[i].d;
         const float d1 = y[i].d;
@@ -2548,9 +2576,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         }
         sumf += d0*d1*sumi;
     }
-#endif
-
     *s = sumf;
+#endif
 }
 
 static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -2562,19 +2589,21 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     const block_q4_1 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
-    float sumf = 0.0;
-
     // TODO: add AVX / WASM SIMD / etc
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    float summs = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
+        summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2586,33 +2615,22 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
+
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        // interleave
-        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
-        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
-
-        const int16x8_t s0i = vaddq_s16(
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
-
-        const int16x8_t s1i = vaddq_s16(
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
@@ -2637,65 +2655,40 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
+    float summs = 0;
+
     // Main loop
     for (int i = 0; i < nb; ++i) {
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;
-        const float * m0 = &x[i].m;
+
+        summs += x[i].m * (y[i].s0 + y[i].s1);
 
         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
 
         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
         const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
 
-        // Get absolute values of x vectors
-        const __m256i ax = _mm256_sign_epi8( bx, bx );
-
-        // Sign the values of the y vectors
-        const __m256i sy = _mm256_sign_epi8( by, bx );
-
-        // Perform multiplication and create 16-bit values
-        const __m256i dot = _mm256_maddubs_epi16( ax, sy );
-        const __m256i ones = _mm256_set1_epi16( 1 );
-        const __m256i xy_q = _mm256_madd_epi16( ones, dot );
-
-        // Convert to vector of 8 int32_t to 8 floats
-        const __m256 xy = _mm256_cvtepi32_ps( xy_q );
+        const __m256 xy = mul_sum_i8_pairs_float(bx, by);
 
         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
-
-        // Compute sum of y values
-        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
-
-        // Accumulate d1*m0*y
-        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }
 
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc) + summs;
 #else
     // scalar
+    float sumf = 0.0;
     for (int i = 0; i < nb; i++) {
         const float d0 = x[i].d;
         const float m0 = x[i].m;
@@ -2717,9 +2710,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
             sumf += f0*f2 + f1*f3;
         }
     }
-#endif
-
     *s = sumf;
+#endif
 }
 
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -2732,8 +2724,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
     const block_q4_2 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
-    float sumf = 0.0;
-
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2811,7 +2801,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2833,32 +2823,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
 
         __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        // Get absolute values of x vectors
-        const __m256i ax = _mm256_sign_epi8(bx, bx);
-        // Sign the values of the y vectors
-        const __m256i sy = _mm256_sign_epi8(by, bx);
-        // Perform multiplication and create 16-bit values
-        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-
-        const __m256i ones = _mm256_set1_epi16(1);
-        __m256i xy_q = _mm256_madd_epi16(ones, dot);
-
-        /* Convert to vectore of 8 int32_t to 8 floats */
-        __m256 q = _mm256_cvtepi32_ps(xy_q);
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
         /* Multiply q with scale and accumulate */
         acc = _mm256_fmadd_ps(d, q, acc);
     }
 
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps(acc, 1);
-    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
-    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
-    res = _mm_add_ss(res, _mm_movehdup_ps(res));
-
-    sumf = _mm_cvtss_f32(res);
+    *s = hsum_float_8(acc);
 #else
     // scalar
+    float sumf = 0.0;
     for (int i = 0; i < nb; i++) {
         const uint8_t * restrict x0 = x[2*i + 0].qs;
         const uint8_t * restrict x1 = x[2*i + 1].qs;
@@ -2893,9 +2867,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
         sumf += (d0 * y[i].d) * sumi_0;
         sumf += (d1 * y[i].d) * sumi_1;
     }
-#endif
-
     *s = sumf;
+#endif
 }
 
 static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -2908,96 +2881,91 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
     const block_q4_3 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
-    float sumf = 0.0;
-
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-    for (int i = 0; i < nb; i += 2) {
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
         const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
         const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
-        const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
-        const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
 
         const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
 
-        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
-        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
-        const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
-        const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
-
-        const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
-        const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
-        const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
-        const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
+        summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
+        summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
 
         const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
-        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
 
         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, vdupq_n_u8(0xf)));
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
-        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
         // interleave
         const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
         const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
-        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
-        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
 
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-        const int8x16_t v1_1l = vld1q_s8(y1->qs);
-        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-        const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
-        const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
-
-        const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
-        const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
+        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
+        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
 
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
-
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
+        const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
+        const __m256 dx = _mm256_set_m128(d1, d0);
+
+        const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
+        const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
+        const __m256 mx = _mm256_set_m128(m1, m0);
+
+        const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
+        const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
+        const __m256i bx = _mm256_set_m128i(bx1, bx0);
+
+        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
+        const __m256 syf = sum_i16_pairs_float(syi);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
+        acc = _mm256_fmadd_ps(sxy, dy, acc);
+    }
+
+    *s = hsum_float_8(acc);
 #else
     // scalar
+    float sumf = 0.0;
     for (int i = 0; i < nb; i++) {
         const uint8_t * restrict x0 = x[2*i + 0].qs;
         const uint8_t * restrict x1 = x[2*i + 1].qs;
@@ -3008,9 +2976,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
         const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
         const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
 
-        int sy_0 = 0;
-        int sy_1 = 0;
-
         int sxy_0 = 0;
         int sxy_1 = 0;
 
@@ -3030,19 +2995,14 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
             const int y0_1 = y0[2*(j + QK8_0/4) + 0];
             const int y1_1 = y0[2*(j + QK8_0/4) + 1];
 
-            sy_0 += y0_0 + y1_0;
-            sy_1 += y0_1 + y1_1;
-
             sxy_0 += x0_0*y0_0 + x1_0*y1_0;
             sxy_1 += x0_1*y0_1 + x1_1*y1_1;
         }
 
-        sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
-        sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
+        sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
     }
-#endif
-
     *s = sumf;
+#endif
 }
 
 
@@ -3720,7 +3680,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
         // initialize cuBLAS
         #if defined(GGML_USE_CUBLAS)
-        init_cublas();
+        ggml_init_cublas();
         #endif
 
         is_first_call = false;
@@ -7566,18 +7526,16 @@ static void ggml_compute_forward_mul_mat_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        float *d_X = NULL;
-        float *d_Y = NULL;
-        float *d_D = NULL;
         const float alpha = 1.0f;
         const float beta = 0.0f;
         const int x_ne = ne01 * ne10;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
-        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+        size_t x_size, y_size, d_size;
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #endif
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7589,19 +7547,19 @@ static void ggml_compute_forward_mul_mat_f32(
 
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
                 // compute
                 CUBLAS_CHECK(
-                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
                             &alpha, d_X, ne00,
                                     d_Y, ne10,
                             &beta,  d_D, ne01));
 
                 // copy data to host
-                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
 #else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -7613,10 +7571,10 @@ static void ggml_compute_forward_mul_mat_f32(
             }
         }
 #if defined(GGML_USE_CUBLAS)
-        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        CUDA_CHECK(cudaFree(d_X));
-        CUDA_CHECK(cudaFree(d_Y));
-        CUDA_CHECK(cudaFree(d_D));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
 #endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
@@ -7766,18 +7724,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #if defined(GGML_USE_CUBLAS)
         ggml_fp16_t * const wdata = params->wdata;
 
-        float *d_X = NULL;
-        float *d_Y = NULL;
-        float *d_D = NULL;
         const float alpha = 1.0f;
         const float beta = 0.0f;
         const int x_ne = ne01 * ne10;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
-        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+        size_t x_size, y_size, d_size;
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #else
         float * const wdata = params->wdata;
 #endif
@@ -7811,12 +7767,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
                 // compute
                 CUBLAS_CHECK(
-                    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                    cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
                             &alpha, d_X, CUDA_R_16F, ne00,
                                     d_Y, CUDA_R_16F, ne10,
@@ -7825,7 +7781,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                             CUBLAS_GEMM_DEFAULT));
 
                 // copy data to host
-                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
 #else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7843,10 +7799,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        CUDA_CHECK(cudaFree(d_X));
-        CUDA_CHECK(cudaFree(d_Y));
-        CUDA_CHECK(cudaFree(d_D));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
 #endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
@@ -8014,20 +7970,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        float *d_X = NULL;
-        float *d_Y = NULL;
-        float *d_D = NULL;
-        float *d_Q = NULL;
         const float alpha = 1.0f;
         const float beta = 0.0f;
         const int x_ne = ne01 * ne10;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
-        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
+        size_t x_size, y_size, d_size, q_size;
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
 
         void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream)  = NULL;
         if (type == GGML_TYPE_Q4_0) {
@@ -8057,9 +8010,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
                 // copy and dequantize on device
                 CUDA_CHECK(
                     cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
-                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
 
-                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
                 CUDA_CHECK(cudaGetLastError());
 #else
                 {
@@ -8075,18 +8028,18 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
                 // compute
                 CUBLAS_CHECK(
-                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
                             &alpha, d_X, ne00,
                                     d_Y, ne10,
                             &beta,  d_D, ne01));
 
                 // copy data to host
-                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
 #else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -8099,11 +8052,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        CUDA_CHECK(cudaFree(d_X));
-        CUDA_CHECK(cudaFree(d_Y));
-        CUDA_CHECK(cudaFree(d_D));
-        CUDA_CHECK(cudaFree(d_Q));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
+        ggml_cuda_pool_free(d_Q, q_size);
 #endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);