]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : sync latest ggml + llama.cpp updates (quantization)
authorGeorgi Gerganov <redacted>
Sat, 29 Apr 2023 09:31:52 +0000 (12:31 +0300)
committerGeorgi Gerganov <redacted>
Sat, 29 Apr 2023 09:32:28 +0000 (12:32 +0300)
ggml-cuda.cu [new file with mode: 0644]
ggml-cuda.h [new file with mode: 0644]
ggml.c
ggml.h

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
new file mode 100644 (file)
index 0000000..5a2701c
--- /dev/null
@@ -0,0 +1,365 @@
+#include <stdint.h>
+#include <stdio.h>
+#include <cuda_fp16.h>
+#include <atomic>
+#include "ggml-cuda.h"
+
+typedef uint16_t ggml_fp16_t;
+static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+#define QK4_0 32
+typedef struct {
+    float   d;              // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+    float   d;              // delta
+    float   m;              // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK4_2 16
+typedef struct {
+    __half  d;              // delta
+    uint8_t qs[QK4_2 / 2];  // nibbles / quants
+} block_q4_2;
+static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+    __half d;               // delta
+    uint8_t qh[4];          // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2];  // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+    __half d;               // delta
+    __half m;               // min
+    uint32_t qh;            // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2];  // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    float   d;              // delta
+    int8_t  qs[QK8_0];      // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+
+static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_0; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        y[i*QK4_0 + l + 0] = v0;
+        y[i*QK4_0 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+    const float m = x[i].m;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_1; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = vi0*d + m;
+        const float v1 = vi1*d + m;
+
+        y[i*QK4_1 + l + 0] = v0;
+        y[i*QK4_1 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
+    const block_q4_2 * x = (const block_q4_2 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_2; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        y[i*QK4_2 + l + 0] = v0;
+        y[i*QK4_2 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+
+    const uint8_t * pp = x[i].qs;
+
+    uint32_t qh;
+    memcpy(&qh, x[i].qh, sizeof(qh));
+
+    for (int l = 0; l < QK5_0; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+        const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+        const int8_t vi0 = ((vi & 0xf) | vh0);
+        const int8_t vi1 = ((vi >>  4) | vh1);
+
+        const float v0 = (vi0 - 16)*d;
+        const float v1 = (vi1 - 16)*d;
+
+        y[i*QK5_0 + l + 0] = v0;
+        y[i*QK5_0 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+    const float m = x[i].m;
+
+    const uint8_t * pp = x[i].qs;
+
+    const uint32_t qh = x[i].qh;
+
+    for (int l = 0; l < QK5_1; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+        const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+        const int8_t vi0 = (vi & 0xf) | vh0;
+        const int8_t vi1 = (vi >>  4) | vh1;
+
+        const float v0 = vi0*d + m;
+        const float v1 = vi1*d + m;
+
+        y[i*QK5_1 + l + 0] = v0;
+        y[i*QK5_1 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+
+    const int8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK8_0; l++) {
+        const int8_t vi = pp[l];
+
+        y[i*QK8_0 + l] = vi*d;
+    }
+}
+
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_0;
+    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_1;
+    dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_2;
+    dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK5_0;
+    dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK5_1;
+    dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK8_0;
+    dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q4_2:
+            return dequantize_row_q4_2_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_row_q5_0_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_row_q5_1_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_row_q8_0_cuda;
+        default:
+            return nullptr;
+    }
+}
+
+// buffer pool for cuda
+#define MAX_CUDA_BUFFERS 16
+
+struct scoped_spin_lock {
+    std::atomic_flag& lock;
+    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
+        while (lock.test_and_set(std::memory_order_acquire)) {
+            ; // spin
+        }
+    }
+    ~scoped_spin_lock() {
+        lock.clear(std::memory_order_release);
+    }
+    scoped_spin_lock(const scoped_spin_lock&) = delete;
+    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
+};
+
+struct cuda_buffer {
+    void * ptr = nullptr;
+    size_t size = 0;
+};
+
+static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.size >= size && b.ptr != nullptr) {
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
+        }
+    }
+    void * ptr;
+    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
+    *actual_size = size;
+    return ptr;
+}
+
+void ggml_cuda_pool_free(void * ptr, size_t size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.ptr == nullptr) {
+            b.ptr = ptr;
+            b.size = size;
+            return;
+        }
+    }
+    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+    CUDA_CHECK(cudaFree(ptr));
+}
+
+cublasHandle_t g_cublasH = nullptr;
+cudaStream_t g_cudaStream = nullptr;
+cudaStream_t g_cudaStream2 = nullptr;
+cudaEvent_t g_cudaEvent = nullptr;
+
+void ggml_init_cublas() {
+    if (g_cublasH == nullptr) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&g_cublasH));
+        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
+        CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
+
+        // create additional stream and event for synchronization
+        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
+        CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+    }
+}
+
+cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
+    const uint64_t ne0 = src->ne[0];
+    const uint64_t ne1 = src->ne[1];
+    const uint64_t nb0 = src->nb[0];
+    const uint64_t nb1 = src->nb[1];
+    const uint64_t nb2 = src->nb[2];
+    const uint64_t nb3 = src->nb[3];
+    const enum ggml_type type = src->type;
+    const size_t ts = ggml_type_size(type);
+    const size_t bs = ggml_blck_size(type);
+
+    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
+    } else if (nb0 == ts) {
+        return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
+    } else {
+        for (uint64_t i1 = 0; i1 < ne1; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1
+            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
+            if (r != cudaSuccess) return r;
+        }
+        return cudaSuccess;
+    }
+}
+
+void * ggml_cuda_host_malloc(size_t size) {
+    void * ptr;
+    CUDA_CHECK(cudaMallocHost((void **) &ptr, size));
+    return ptr;
+}
+
+void ggml_cuda_host_free(void * ptr) {
+    CUDA_CHECK(cudaFreeHost(ptr));
+}
diff --git a/ggml-cuda.h b/ggml-cuda.h
new file mode 100644 (file)
index 0000000..36782d9
--- /dev/null
@@ -0,0 +1,54 @@
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#include "ggml.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#define CUDA_CHECK(err)                                                                 \
+    do {                                                                                \
+        cudaError_t err_ = (err);                                                       \
+        if (err_ != cudaSuccess) {                                                      \
+            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
+                cudaGetErrorString(err_));                                              \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                                               \
+    do {                                                                                \
+        cublasStatus_t err_ = (err);                                                    \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+extern cublasHandle_t g_cublasH;
+extern cudaStream_t g_cudaStream;
+extern cudaStream_t g_cudaStream2;
+extern cudaEvent_t g_cudaEvent;
+
+void   ggml_init_cublas(void);
+void * ggml_cuda_host_malloc(size_t size);
+void   ggml_cuda_host_free(void * ptr);
+
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
+void   ggml_cuda_pool_free(void * ptr, size_t size);
+
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+
+cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
+
+typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
+dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/ggml.c b/ggml.c
index ca3b7b95c46f5a623789fc4ff1972f5cfd0a0522..2ec0c0bf9d9f4408bb736e6b32c8a205bbe489c5 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -19,6 +19,7 @@
 #include <inttypes.h>
 #include <stdio.h>
 #include <float.h>
+#include <limits.h>
 
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
@@ -142,10 +143,14 @@ inline static void* ggml_aligned_malloc(size_t size) {
         } \
     } while (0)
 
-#ifdef GGML_USE_ACCELERATE
+#if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
-#elif GGML_USE_OPENBLAS
+#elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
+#elif defined(GGML_USE_CUBLAS)
+#include "ggml-cuda.h"
+#elif defined(GGML_USE_CLBLAST)
+#include "ggml-opencl.h"
 #endif
 
 #undef MIN
@@ -325,6 +330,20 @@ static ggml_fp16_t table_exp_f16[1 << 16];
 // precomputed f32 table for f16 (256 KB)
 static float table_f32_f16[1 << 16];
 
+#if defined(__ARM_NEON)
+#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)
+
+// precomputed tables for expanding 8bits to 8 bytes (shl 4)
+static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
+#endif
+
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
 // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
 // This is also true for POWER9.
@@ -427,14 +446,69 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-#define QK 32
+#if __AVX__ || __AVX2__ || __AVX512F__
+// Unpack 16 4-bit fields into 16 bytes
+// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
+static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
+{
+    // Load 8 bytes from memory
+    __m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
+
+    // Expand bytes into uint16_t values
+    __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m128i lowMask = _mm_set1_epi8( 0xF );
+    __m128i high = _mm_andnot_si128( lowMask, bytes );
+    __m128i low = _mm_and_si128( lowMask, bytes );
+    high = _mm_slli_epi16( high, 4 );
+    bytes = _mm_or_si128( low, high );
+    return bytes;
+}
+
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
 
-// AVX routines provided by GH user Const-me
-// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
 #if __AVX2__ || __AVX512F__
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = _mm256_set_epi64x(
+        0x0303030303030303, 0x0202020202020202,
+        0x0101010101010101, 0x0000000000000000);
+    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytes = _mm256_or_si256(bytes, bit_mask);
+    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i bytesFromNibbles( const uint8_t* rsi )
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     // Load 16 bytes from memory
     __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -451,9 +525,38 @@ static inline __m256i bytesFromNibbles( const uint8_t* rsi )
     return bytes;
 }
 
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+#if __AVXVNNI__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_float(dot);
+#endif
+}
+
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000
+    bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh
+    return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh
+#else
     const __m256i lowByte = _mm256_set1_epi16( 0xFF );
     __m256i high = _mm256_andnot_si256( lowByte, bytes );
     __m256i low = _mm256_and_si256( lowByte, bytes );
@@ -464,25 +567,9 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r0 = _mm256_castsi256_si128( bytes );
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
+#endif
 }
-#elif __AVX__
-static inline __m128i bytesFromNibbles( const uint8_t* rsi )
-{
-    // Load 8 bytes from memory
-    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
-
-    // Expand bytes into uint16_t values
-    __m128i bytes = _mm_cvtepu8_epi16( tmp );
-
-    // Unpack values into individual bytes
-    const __m128i lowMask = _mm_set1_epi8( 0xF );
-    __m128i high = _mm_andnot_si128( lowMask, bytes );
-    __m128i low = _mm_and_si128( lowMask, bytes );
-    high = _mm_slli_epi16( high, 4 );
-    bytes = _mm_or_si128( low, high );
-    return bytes;
-}
-
+#else
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -499,6 +586,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
     return _mm_packus_epi16( bytes1, bytes2);
 }
 #endif
+#endif // __AVX__ || __AVX2__ || __AVX512F__
 
 #if __ARM_NEON
 
@@ -516,6 +604,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
         (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
 }
 
+inline static int16_t vaddvq_s8(int8x16_t v) {
+    return
+        (int16_t)vgetq_lane_s8(v, 0)  + (int16_t)vgetq_lane_s8(v, 1)  +
+        (int16_t)vgetq_lane_s8(v, 2)  + (int16_t)vgetq_lane_s8(v, 3)  +
+        (int16_t)vgetq_lane_s8(v, 4)  + (int16_t)vgetq_lane_s8(v, 5)  +
+        (int16_t)vgetq_lane_s8(v, 6)  + (int16_t)vgetq_lane_s8(v, 7)  +
+        (int16_t)vgetq_lane_s8(v, 8)  + (int16_t)vgetq_lane_s8(v, 9)  +
+        (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
+        (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
+        (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
+}
+
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -571,51 +671,92 @@ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
 #endif
 #endif
 
-// method 5
-// blocks of QK elements
-// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
+
+#define QK4_0 32
 typedef struct {
-    float   d; // delta
-    uint8_t qs[QK / 2]; // nibbles / quants
+    float   d;          // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
-// method 4
-// blocks of QK elements
-// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
+#define QK4_1 32
 typedef struct {
-    float   d;
-    float   m;
-    uint8_t qs[QK / 2]; // nibbles / quants
+    float   d;          // delta
+    float   m;          // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK4_2 16
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qs[QK4_2 / 2]; // nibbles / quants
+} block_q4_2;
+static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    ggml_fp16_t m;         // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    float   d;          // delta
+    int8_t  qs[QK8_0];  // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+
+#define QK8_1 32
+typedef struct {
+    float   d;          // delta
+    float   s0;         // d * sum(qs[i]) low
+    float   s1;         // d * sum(qs[i]) high
+    int8_t  qs[QK8_1];  // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
 
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
-    uint8_t pp[QK/2];
+    uint8_t pp[QK4_0/2];
 
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
+        float max = 0.0f;
 
-        for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
-            amax = MAX(amax, fabsf(v));
+        for (int l = 0; l < QK4_0; l++) {
+            const float v = x[i*QK4_0 + l];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max = v;
+            }
         }
 
-        const float d = amax / ((1 << 3) - 1);
+        const float d = max / -8;
         const float id = d ? 1.0f/d : 0.0f;
 
         y[i].d = d;
 
-        for (int l = 0; l < QK; l += 2) {
-            const float v0 = x[i*QK + l + 0]*id;
-            const float v1 = x[i*QK + l + 1]*id;
+        for (int l = 0; l < QK4_0; l += 2) {
+            const float v0 = x[i*QK4_0 + l + 0]*id;
+            const float v1 = x[i*QK4_0 + l + 1]*id;
 
-            const uint8_t vi0 = (int8_t)roundf(v0) + 8;
-            const uint8_t vi1 = (int8_t)roundf(v1) + 8;
+            const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
+            const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
 
             assert(vi0 < 16);
             assert(vi1 < 16);
@@ -628,35 +769,49 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 }
 
 static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     block_q4_0 * restrict y = vy;
 
 #if defined(__POWER9_VECTOR__)
     const vector float v85 = vec_splats(8.5f);
+    const vector signed int v15 = vec_splats(15);
     for (int i = 0; i < nb; i++) {
-        float amax = 0.0f; // absolute max
+        float max = 0.0f;
+        float min = 0.0f;
 
         vector float srcv [8];
-        vector float asrcv[8];
-        vector float amaxv[8];
+        vector float maxv[8];
+        vector float minv[8];
 
         for (int l = 0; l < 8; l++) srcv[l]  = *(vector float *)(x + i*32 + 4*l);
-        for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
-
-        for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
-        //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]);
-        amaxv[0] = vec_max(amaxv[0], amaxv[2]);
-        amaxv[4] = vec_max(amaxv[4], amaxv[6]);
-        //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]);
-        amaxv[0] = vec_max(amaxv[0], amaxv[4]);
-
-        amax = MAX(
-                MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)),
-                MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3)));
-
-        const float d = amax / ((1 << 3) - 1);
+        //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
+
+        for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
+        //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
+        maxv[0] = vec_max(maxv[0], maxv[2]);
+        maxv[4] = vec_max(maxv[4], maxv[6]);
+        //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
+        maxv[0] = vec_max(maxv[0], maxv[4]);
+
+        for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
+        //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
+        minv[0] = vec_min(minv[0], minv[2]);
+        minv[4] = vec_min(minv[4], minv[6]);
+        //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
+        minv[0] = vec_min(minv[0], minv[4]);
+
+
+        max = MAX(
+                MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
+                MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
+        min = MIN(
+                MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
+                MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
+
+        const float magnitude = max >= fabsf(min) ? max : min;
+        const float d = magnitude / -8;
         const float id = d ? 1.0/d : 0.0;
 
         y[i].d = d;
@@ -666,27 +821,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         for (int l = 0; l < 8; l++) {
             const vector float vf  = vec_madd(srcv[l], vid, v85);
             const vector signed int vi = vec_signed(vf);
+            const vector signed int vc = vec_min(vi, v15);
 
-            pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4);
-            pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4);
+            pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
+            pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
         }
     }
 #elif __ARM_NEON
     for (int i = 0; i < nb; i++) {
         float32x4_t srcv [8];
-        float32x4_t asrcv[8];
-        float32x4_t amaxv[8];
+        float32x4_t maxv[8];
+        float32x4_t minv[8];
 
         for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
-        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
 
-        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
-        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
-        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+        for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
+        for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
+        for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
 
-        const float amax = vmaxvq_f32(amaxv[0]);
+        for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
+        for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
+        for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
+
+        const float max = vmaxvq_f32(maxv[0]);
+        const float min = vminvq_f32(minv[0]);
 
-        const float d = amax / ((1 << 3) - 1);
+        const float magnitude = max >= fabsf(min) ? max : min;
+        const float d = magnitude / -8;
         const float id = d ? 1.0f/d : 0.0f;
 
         y[i].d = d;
@@ -695,9 +856,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
             const int32x4_t   vi = vcvtq_s32_f32(vf);
+            const int32x4_t   vc = vminq_s32(vi, vdupq_n_s32(15));
 
-            y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
-            y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+            y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
+            y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
         }
     }
 #elif defined(__AVX2__)
@@ -709,22 +871,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         __m256 v3 = _mm256_loadu_ps( x + 24 );
         x += 32;
 
-        // Compute max(abs(e)) for the block
-        const __m256 signBit = _mm256_set1_ps( -0.0f );
-        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+        // Compute max for the block
+        __m256 max  = _mm256_max_ps( v0, v1 );
+        __m256 maxTmp = _mm256_max_ps( v2, v3 );
+        max = _mm256_max_ps( max, maxTmp );
 
-        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
         max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
         max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
         const float maxScalar = _mm_cvtss_f32( max4 );
 
+        // Compute min for the block
+        __m256 min  = _mm256_min_ps( v0, v1 );
+        __m256 minTmp = _mm256_min_ps( v2, v3 );
+        min = _mm256_min_ps( min, minTmp );
+
+        __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
+        min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+        min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+        const float minScalar = _mm_cvtss_f32( min4 );
+
         // Quantize these floats
-        const float d = maxScalar / 7.0f;
+        const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
+        const float d = magnitude / -8.0f;
         y[i].d = d;
-        const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
+        const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
         const __m256 mul = _mm256_set1_ps( id );
 
         // Apply the multiplier
@@ -757,9 +928,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
         i0 = _mm256_permutevar8x32_epi32( i0, perm );
 
-        // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
+        // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
         const __m256i off = _mm256_set1_epi8( 8 );
         i0 = _mm256_add_epi8( i0, off );
+        const __m256i maxNibble = _mm256_set1_epi8( 15 );
+        i0 = _mm256_min_epi8( i0, maxNibble );
 
         // Compress the vector into 4 bit/value, and store
         __m128i res = packNibbles( i0 );
@@ -774,22 +947,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         __m256 v3 = _mm256_loadu_ps( x + 24 );
         x += 32;
 
-        // Compute max(abs(e)) for the block
-        const __m256 signBit = _mm256_set1_ps( -0.0f );
-        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+        // Compute max for the block
+        __m256 max  = _mm256_max_ps( v0, v1 );
+        __m256 maxTmp = _mm256_max_ps( v2, v3 );
+        max = _mm256_max_ps( max, maxTmp );
 
-        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
         max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
         max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
         const float maxScalar = _mm_cvtss_f32( max4 );
 
+        // Compute min for the block
+        __m256 min  = _mm256_min_ps( v0, v1 );
+        __m256 minTmp = _mm256_min_ps( v2, v3 );
+        min = _mm256_min_ps( min, minTmp );
+
+        __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
+        min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+        min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+        const float minScalar = _mm_cvtss_f32( min4 );
+
         // Quantize these floats
-        const float d = maxScalar / 7.0f;
+        const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
+        const float d = magnitude / -8.0f;
         y[i].d = d;
-        const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
+        const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
         const __m256 mul = _mm256_set1_ps( id );
 
         // Apply the multiplier
@@ -830,10 +1012,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         ni0 = _mm_packs_epi16( ni0, ni2 );
         ni4 = _mm_packs_epi16( ni4, ni6 );
 
-        // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
-        const __m128i off = _mm_set1_epi8( 8);
+        // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
+        const __m128i off = _mm_set1_epi8( 8 );
         ni0 = _mm_add_epi8( ni0, off );
         ni4 = _mm_add_epi8( ni4, off );
+        const __m128i maxNibble = _mm_set1_epi8( 15 );
+        ni0 = _mm_min_epi8( ni0, maxNibble );
+        ni4 = _mm_min_epi8( ni4, maxNibble );
 
         // Compress the vector into 4 bit/value, and store
         __m128i res = packNibbles( ni0, ni4 );
@@ -841,24 +1026,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
     }
 #elif defined(__wasm_simd128__)
     for (int i = 0; i < nb; i++) {
-        float amax = 0.0f; // absolute max
+        float max = 0.0f;
+        float min = 0.0f;
 
         v128_t srcv [8];
-        v128_t asrcv[8];
-        v128_t amaxv[8];
+        v128_t maxv[8];
+        v128_t minv[8];
 
         for (int l = 0; l < 8; l++) srcv[l]  = wasm_v128_load(x + i*32 + 4*l);
-        for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
 
-        for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
-        for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
-        for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
+        for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
+        for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
+        for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
 
-        amax = MAX(
-                MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
-                MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
+        for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
+        for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
+        for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
 
-        const float d = amax / ((1 << 3) - 1);
+        max = MAX(
+                MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
+                MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
+        min = MIN(
+                MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
+                MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
+
+        const float magnitude = max >= fabsf(min) ? max : min;
+        const float d = magnitude / -8;
         const float id = d ? 1.0/d : 0.0;
 
         y[i].d = d;
@@ -867,9 +1060,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
             const v128_t v  = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
             const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
             const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
+            const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));
 
-            y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
-            y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
+            y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
+            y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
         }
     }
 #else
@@ -879,19 +1073,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 }
 
 static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     block_q4_1 * restrict y = vy;
 
-    uint8_t pp[QK/2];
+    uint8_t pp[QK4_1/2];
 
     for (int i = 0; i < nb; i++) {
         float min = FLT_MAX;
         float max = -FLT_MAX;
 
-        for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
+        for (int l = 0; l < QK4_1; l++) {
+            const float v = x[i*QK4_1 + l];
             if (v < min) min = v;
             if (v > max) max = v;
         }
@@ -902,9 +1096,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
         y[i].d = d;
         y[i].m = min;
 
-        for (int l = 0; l < QK; l += 2) {
-            const float v0 = (x[i*QK + l + 0] - min)*id;
-            const float v1 = (x[i*QK + l + 1] - min)*id;
+        for (int l = 0; l < QK4_1; l += 2) {
+            const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
+            const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
 
             const uint8_t vi0 = roundf(v0);
             const uint8_t vi1 = roundf(v1);
@@ -920,9 +1114,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
 }
 
 static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
+    assert(k % QK4_1 == 0);
 
-    const int nb = k / QK;
+    const int nb = k / QK4_1;
 
     block_q4_1 * restrict y = vy;
 
@@ -1006,7 +1200,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
         float32x4_t minv[8];
         float32x4_t maxv[8];
 
-        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
+        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
 
         for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
         for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -1042,178 +1236,555 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 #endif
 }
 
-static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+// reference implementation for deterministic creation of model files
+static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
+    assert(k % QK4_2 == 0);
 
-    const block_q4_0 * restrict x = vx;
+    const int nb = k / QK4_2;
 
-#if defined(__AVX2__)
     for (int i = 0; i < nb; i++) {
-        // scale factor
-        const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
+        float amax = 0.0f; // absolute max
+        float max = 0.0f;
 
-        const uint8_t * restrict pp = x[i].qs;
+        for (int l = 0; l < QK4_2; l++) {
+            const float v = x[i*QK4_2 + l];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max = v;
+            }
+        }
 
-        for (int l = 0; l < QK; l += 32) {
-            // Load 32x4-bit integers into 32x8-bit integers
-            __m256i vx8 = bytesFromNibbles(pp+l/2);
+        const float d = max / -8;
 
-            // Subtract 8 from the integers
-            vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
+        const float id = d ? 1.0f/d : 0.0f;
 
-            // Convert to 16-bit int
-            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
-            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+        y[i].d = GGML_FP32_TO_FP16(d);
 
-            // Convert to 32-bit int -> float 32
-            const __m256 vf[4] = {
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
-            };
+        for (int l = 0; l < QK4_2; l += 2) {
+            const float v0 = x[i*QK4_2 + l + 0]*id;
+            const float v1 = x[i*QK4_2 + l + 1]*id;
 
-            // Scale and store
-            for (int j = 0; j < 4; j++) {
-                const __m256 result = _mm256_mul_ps(vf[j], d_v);
-                _mm256_storeu_ps(y + i * QK + l + j*8, result);
-            }
+            const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
+            const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
         }
     }
-#elif defined(__ARM_NEON)
-    for (int i = 0; i < nb; i++) {
-        const float32x4_t vd = vdupq_n_f32(x[i].d);
+}
 
-        const uint8_t * restrict pp = x[i].qs;
+static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_2 == 0);
 
-        for (int l = 0; l < QK; l += 16) {
-            // Load 16x4-bit integers into 8x8-bit integers
-            const uint8x8_t v8 = vld1_u8(pp + l/2);
+    block_q4_2 * restrict y = vy;
 
-            // Expand 4-bit qs to 8-bit bytes
-            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
-            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+    quantize_row_q4_2_reference(x, y, k);
+}
 
-            // Convert to signed 8-bit integers
-            const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
-            const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
+static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
+    assert(k % QK5_0 == 0);
+    const int nb = k / QK5_0;
 
-            // Subtract 8 from each byte
-            const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
-            const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max = 0.0f;
 
-            // Interleave and combine
-            const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
-            const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
+        for (int l = 0; l < QK5_0; l++) {
+            const float v = x[i*QK5_0 + l];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max = v;
+            }
+        }
 
-            const int8x16_t vq = vcombine_s8(vx_0, vx_1);
+        const float d = max / -16;
+        const float id = d ? 1.0f/d : 0.0f;
 
-            // convert to 2x int16x8_t
-            const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
-            const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
+        y[i].d = GGML_FP32_TO_FP16(d);
 
-            // convert to 4x float32x4_t
-            const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
-            const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
-            const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
-            const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
+        uint32_t qh = 0;
 
-            // Multiply by d
-            const float32x4_t r0 = vmulq_f32(vf_0, vd);
-            const float32x4_t r1 = vmulq_f32(vf_1, vd);
-            const float32x4_t r2 = vmulq_f32(vf_2, vd);
-            const float32x4_t r3 = vmulq_f32(vf_3, vd);
+        for (int l = 0; l < QK5_0; l += 2) {
+            const float v0 = x[i*QK5_0 + l + 0]*id;
+            const float v1 = x[i*QK5_0 + l + 1]*id;
 
-            // Store
-            vst1q_f32(y + i*QK + l +  0, r0);
-            vst1q_f32(y + i*QK + l +  4, r1);
-            vst1q_f32(y + i*QK + l +  8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
+            const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
+            const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
+
+            y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((vi0 & 0x10) >> 4) << (l + 0);
+            qh |= ((vi1 & 0x10) >> 4) << (l + 1);
         }
+
+        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
     }
-#else
-    // scalar
+}
+
+static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK5_0 == 0);
+
+    block_q5_0 * restrict y = vy;
+
+    quantize_row_q5_0_reference(x, y, k);
+}
+
+static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
+    assert(k % QK5_1 == 0);
+    const int nb = k / QK5_1;
+
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
 
-        const uint8_t * restrict pp = x[i].qs;
+        for (int l = 0; l < QK5_1; l++) {
+            const float v = x[i*QK5_1 + l];
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
 
-        for (int l = 0; l < QK; l += 2) {
-            const uint8_t vi = pp[l/2];
+        const float d = (max - min) / ((1 << 5) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
 
-            const int8_t vi0 = vi & 0xf;
-            const int8_t vi1 = vi >> 4;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
 
-            const float v0 = (vi0 - 8)*d;
-            const float v1 = (vi1 - 8)*d;
+        uint32_t qh = 0;
 
-            //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
+        for (int l = 0; l < QK5_1; l += 2) {
+            const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
+            const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
+
+            const uint32_t vi0 = (int) (v0 + 0.5f);
+            const uint32_t vi1 = (int) (v1 + 0.5f);
 
-            y[i*QK + l + 0] = v0;
-            y[i*QK + l + 1] = v1;
+            y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
 
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((vi0 & 0x10) >> 4) << (l + 0);
+            qh |= ((vi1 & 0x10) >> 4) << (l + 1);
         }
+
+        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
     }
-#endif
 }
 
-static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK5_1 == 0);
 
-    const block_q4_1 * restrict x = vx;
+    block_q5_1 * restrict y = vy;
+
+    quantize_row_q5_1_reference(x, y, k);
+}
+
+// reference implementation for deterministic creation of model files
+static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
 
-#if defined(__AVX2__)
     for (int i = 0; i < nb; i++) {
-        const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
-        const __m256 d_m = _mm256_broadcast_ss(&x[i].m);
+        float amax = 0.0f; // absolute max
 
-        const uint8_t * restrict pp = x[i].qs;
+        for (int l = 0; l < QK8_0; l++) {
+            const float v = x[i*QK8_0 + l];
+            amax = MAX(amax, fabsf(v));
+        }
 
-        for (int l = 0; l < QK; l += 32) {
-            // Load 32x4-bit integers into 32x8-bit integers
-            __m256i vx8 = bytesFromNibbles(pp+l/2);
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
 
-            // Convert to 16-bit int
-            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
-            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+        y[i].d = d;
 
-            // Convert to 32-bit int -> float 32
-            const __m256 vf[4] = {
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
-            };
+        for (int l = 0; l < QK8_0; ++l) {
+            const float v0 = x[i*QK8_0 + l]*id;
 
-            // Scale, add m and store
-            for (int j = 0; j < 4; j++) {
-                const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
-                _mm256_storeu_ps(y + i * QK + l + j*8, result);
-            }
+            y[i].qs[l] = roundf(v0);
         }
     }
-#elif defined(__ARM_NEON)
-    for (int i = 0; i < nb; i++) {
-        const float32x4_t vd = vdupq_n_f32(x[i].d);
-        const float32x4_t vm = vdupq_n_f32(x[i].m);
+}
 
-        const uint8_t * restrict pp = x[i].qs;
+static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK8_0 == 0);
 
-        for (int l = 0; l < QK; l += 16) {
-            // Load 16x4-bit integers into 8x8-bit integers
-            const uint8x8_t v8 = vld1_u8(pp + l/2);
+    block_q8_0 * restrict y = vy;
 
-            // Expand 4-bit qs to 8-bit bytes
-            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
-            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+    quantize_row_q8_0_reference(x, y, k);
+}
 
-            // Interleave and combine
-            const uint8x8_t vx_0 = vzip1_u8(v0, v1);
-            const uint8x8_t vx_1 = vzip2_u8(v0, v1);
+// reference implementation for deterministic creation of model files
+static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int l = 0; l < QK8_1; l++) {
+            const float v = x[i*QK8_1 + l];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        int sum0 = 0;
+        int sum1 = 0;
+
+        for (int l = 0; l < QK8_1/2; ++l) {
+            const float v0 = x[i*QK8_1           + l]*id;
+            const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id;
+
+            y[i].qs[          l] = roundf(v0);
+            y[i].qs[QK8_1/2 + l] = roundf(v1);
+
+            sum0 += y[i].qs[          l];
+            sum1 += y[i].qs[QK8_1/2 + l];
+        }
+
+        y[i].s0 = d * sum0;
+        y[i].s1 = d * sum1;
+    }
+}
+
+static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        int32x4_t accv0 = vdupq_n_s32(0);
+        int32x4_t accv1 = vdupq_n_s32(0);
+
+        // low half
+        for (int l = 0; l < 4; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv0 = vaddq_s32(accv0, vi);
+        }
+
+        // high half
+        for (int l = 4; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv1 = vaddq_s32(accv1, vi);
+        }
+
+        const int32_t sum0 = vaddvq_s32(accv0);
+        const int32_t sum1 = vaddvq_s32(accv1);
+
+        y[i].s0 = d * sum0;
+        y[i].s1 = d * sum1;
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Compute the sum of the quants and set y[i].s
+        //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+        y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
+        y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );     // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );     // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );     // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Compute the sum of the quants and set y[i].s
+        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+        y[i].s0 = d * hsum_i32_4(s0);
+        y[i].s1 = d * hsum_i32_4(s1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#else
+    // scalar
+    quantize_row_q8_1_reference(x, y, k);
+#endif
+}
+
+static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
+
+    const block_q4_0 * restrict x = vx;
+
+#if defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        // scale factor
+        const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_0; l += 32) {
+            // Load 32x4-bit integers into 32x8-bit integers
+            __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
+
+            // Subtract 8 from the integers
+            vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
+
+            // Convert to 16-bit int
+            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
+            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+
+            // Convert to 32-bit int -> float 32
+            const __m256 vf[4] = {
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+            };
+
+            // Scale and store
+            for (int j = 0; j < 4; j++) {
+                const __m256 result = _mm256_mul_ps(vf[j], d_v);
+                _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
+            }
+        }
+    }
+#elif defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        const float32x4_t vd = vdupq_n_f32(x[i].d);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_0; l += 16) {
+            // Load 16x4-bit integers into 8x8-bit integers
+            const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+            // Expand 4-bit qs to 8-bit bytes
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
+            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+            // Convert to signed 8-bit integers
+            const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
+            const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
+
+            // Subtract 8 from each byte
+            const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
+            const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
+
+            // Interleave and combine
+            const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
+            const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
+
+            const int8x16_t vq = vcombine_s8(vx_0, vx_1);
+
+            // convert to 2x int16x8_t
+            const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
+            const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
+
+            // convert to 4x float32x4_t
+            const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
+            const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
+            const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
+            const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
+
+            // Multiply by d
+            const float32x4_t r0 = vmulq_f32(vf_0, vd);
+            const float32x4_t r1 = vmulq_f32(vf_1, vd);
+            const float32x4_t r2 = vmulq_f32(vf_2, vd);
+            const float32x4_t r3 = vmulq_f32(vf_3, vd);
+
+            // Store
+            vst1q_f32(y + i*QK4_0 + l +  0, r0);
+            vst1q_f32(y + i*QK4_0 + l +  4, r1);
+            vst1q_f32(y + i*QK4_0 + l +  8, r2);
+            vst1q_f32(y + i*QK4_0 + l + 12, r3);
+        }
+    }
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d = x[i].d;
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_0; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0x0F;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = (vi0 - 8)*d;
+            const float v1 = (vi1 - 8)*d;
+
+            //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
+
+            y[i*QK4_0 + l + 0] = v0;
+            y[i*QK4_0 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK4_0 + l + 0]));
+            assert(!isnan(y[i*QK4_0 + l + 1]));
+        }
+    }
+#endif
+}
+
+static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
+
+    const block_q4_1 * restrict x = vx;
+
+#if defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
+        const __m256 d_m = _mm256_broadcast_ss(&x[i].m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_1; l += 32) {
+            // Load 32x4-bit integers into 32x8-bit integers
+            __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
+
+            // Convert to 16-bit int
+            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
+            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+
+            // Convert to 32-bit int -> float 32
+            const __m256 vf[4] = {
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+            };
+
+            // Scale, add m and store
+            for (int j = 0; j < 4; j++) {
+                const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
+                _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
+            }
+        }
+    }
+#elif defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        const float32x4_t vd = vdupq_n_f32(x[i].d);
+        const float32x4_t vm = vdupq_n_f32(x[i].m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_1; l += 16) {
+            // Load 16x4-bit integers into 8x8-bit integers
+            const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+            // Expand 4-bit qs to 8-bit bytes
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
+            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+            // Interleave and combine
+            const uint8x8_t vx_0 = vzip1_u8(v0, v1);
+            const uint8x8_t vx_1 = vzip2_u8(v0, v1);
 
             const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
 
@@ -1233,39 +1804,231 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
             const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
 
-            // Store
-            vst1q_f32(y + i*QK + l +  0, r0);
-            vst1q_f32(y + i*QK + l +  4, r1);
-            vst1q_f32(y + i*QK + l +  8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
+            // Store
+            vst1q_f32(y + i*QK4_1 + l +  0, r0);
+            vst1q_f32(y + i*QK4_1 + l +  4, r1);
+            vst1q_f32(y + i*QK4_1 + l +  8, r2);
+            vst1q_f32(y + i*QK4_1 + l + 12, r3);
+        }
+    }
+#else
+    for (int i = 0; i < nb; i++) {
+        const float d = x[i].d;
+        const float m = x[i].m;
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_1; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0x0F;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = vi0*d + m;
+            const float v1 = vi1*d + m;
+
+            y[i*QK4_1 + l + 0] = v0;
+            y[i*QK4_1 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK4_1 + l + 0]));
+            assert(!isnan(y[i*QK4_1 + l + 1]));
+        }
+    }
+#endif
+}
+
+static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK4_2 == 0);
+    const int nb = k / QK4_2;
+
+    const block_q4_2 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0x0F;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = (vi0 - 8)*d;
+            const float v1 = (vi1 - 8)*d;
+
+            y[i*QK4_2 + l + 0] = v0;
+            y[i*QK4_2 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK4_2 + l + 0]));
+            assert(!isnan(y[i*QK4_2 + l + 1]));
+        }
+    }
+}
+
+static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK5_0 == 0);
+    const int nb = k / QK5_0;
+
+    const block_q5_0 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int l = 0; l < QK5_0; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            // extract the 5-th bit from qh
+            const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+            const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+            const int8_t vi0 = (vi & 0x0F) | vh0;
+            const int8_t vi1 = (vi >>   4) | vh1;
+
+            const float v0 = (vi0 - 16)*d;
+            const float v1 = (vi1 - 16)*d;
+
+            y[i*QK5_0 + l + 0] = v0;
+            y[i*QK5_0 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK5_0 + l + 0]));
+            assert(!isnan(y[i*QK5_0 + l + 1]));
         }
     }
-#else
+}
+
+static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK5_1 == 0);
+    const int nb = k / QK5_1;
+
+    const block_q5_1 * restrict x = vx;
+
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
-        const float m = x[i].m;
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 2) {
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int l = 0; l < QK5_1; l += 2) {
             const uint8_t vi = pp[l/2];
 
-            const int8_t vi0 = vi & 0xf;
-            const int8_t vi1 = vi >> 4;
+            // extract the 5-th bit from qh
+            const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+            const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+            const uint8_t vi0 = (vi & 0x0F) | vh0;
+            const uint8_t vi1 = (vi >>   4) | vh1;
 
             const float v0 = vi0*d + m;
             const float v1 = vi1*d + m;
 
-            y[i*QK + l + 0] = v0;
-            y[i*QK + l + 1] = v1;
+            y[i*QK5_1 + l + 0] = v0;
+            y[i*QK5_1 + l + 1] = v1;
 
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isnan(y[i*QK5_1 + l + 0]));
+            assert(!isnan(y[i*QK5_1 + l + 1]));
         }
     }
-#endif
 }
 
+static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    const block_q8_0 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = x[i].d;
+
+        const int8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK8_0; ++l) {
+            y[i*QK8_0 + l] = pp[l]*d;
+        }
+    }
+}
+
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q         = dequantize_row_q4_0,
+        .quantize_row_q           = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q         = dequantize_row_q4_1,
+        .quantize_row_q           = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .quantize_row_q_dot       = quantize_row_q8_1,
+        .vec_dot_q                = ggml_vec_dot_q4_1_q8_1,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+    },
+    [GGML_TYPE_Q4_2] = {
+        .dequantize_row_q         = dequantize_row_q4_2,
+        .quantize_row_q           = quantize_row_q4_2,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_2_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+    },
+    [GGML_TYPE_Q5_0] = {
+        .dequantize_row_q         = dequantize_row_q5_0,
+        .quantize_row_q           = quantize_row_q5_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q5_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+    },
+    [GGML_TYPE_Q5_1] = {
+        .dequantize_row_q         = dequantize_row_q5_1,
+        .quantize_row_q           = quantize_row_q5_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
+        .quantize_row_q_dot       = quantize_row_q8_1,
+        .vec_dot_q                = ggml_vec_dot_q5_1_q8_1,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+    },
+    [GGML_TYPE_Q8_0] = {
+        .dequantize_row_q         = dequantize_row_q8_0,
+        .quantize_row_q           = quantize_row_q8_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q8_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+    },
+    [GGML_TYPE_Q8_1] = {
+        .dequantize_row_q         = NULL,   // TODO
+        .quantize_row_q           = quantize_row_q8_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
+        .quantize_row_q_dot       = quantize_row_q8_1,
+        .vec_dot_q                = NULL,   // TODO
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+    },
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
+
 //
 // simd mappings
 //
@@ -1822,37 +2585,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
-#if __AVX512F__ && QK == 32
-static inline __m512 dot_q4_0_oneblock_avx512(
-    __m512 acc,
-    const block_q4_0 * restrict x,
-    const block_q4_0 * restrict y,
-    int i
-) {
-    // Compute combined scale for the block
-    __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
-
-    __m256i bx = bytesFromNibbles( x[i].qs );
-    __m256i by = bytesFromNibbles( y[i].qs );
-
-    // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-    const __m256i off = _mm256_set1_epi8( 8 );
-    bx = _mm256_sub_epi8( bx, off );
-    by = _mm256_sub_epi8( by, off );
-
-    // Sign-extend 16 signed bytes into int16_t
-    __m512i x32 = _mm512_cvtepi8_epi16( bx );
-    __m512i y32 = _mm512_cvtepi8_epi16( by );
-    // Compute products of int16_t integers, add pairwise
-    __m512i i64 = _mm512_madd_epi16( x32, y32 );
-
-    // Convert int32_t to float
-    __m512 p = _mm512_cvtepi32_ps( i64 );
-    // Apply the scale, and accumulate
-    return _mm512_fmadd_ps( d, p, acc );
-}
-#endif
-
 inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;
 
@@ -1869,552 +2601,853 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
             ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
             ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
 
-            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
-        }
-    }
+            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    GGML_F16_VEC_REDUCE(sumf, sum);
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
+    }
+#endif
+
+    *s = sumf;
+}
+
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+
+    const block_q4_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b   = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b   = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m128i i32[2];
+        for (int j = 0; j < 2; ++j) {
+            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+            __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
+            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m128i off = _mm_set1_epi8( 8 );
+            bx = _mm_sub_epi8( bx, off );
+
+            // Get absolute values of x vectors
+            const __m128i ax = _mm_sign_epi8(bx, bx);
+
+            // Sign the values of the y vectors
+            const __m128i sy = _mm_sign_epi8(by, bx);
+
+            // Perform multiplication and create 16-bit values
+            const __m128i dot = _mm_maddubs_epi16(ax, sy);
+
+            const __m128i ones = _mm_set1_epi16(1);
+            i32[j] = _mm_madd_epi16(ones, dot);
+        }
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    *s = hsum_float_8(acc);
+#else
+    // scalar
+    float sumf = 0.0;
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
+
+        int sumi = 0;
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const int i0 = (int8_t) (v0 & 0x0F) - 8;
+            const int i1 = (int8_t) (v0 >>   4) - 8;
+
+            const int i2 = p1[2*j + 0];
+            const int i3 = p1[2*j + 1];
+
+            sumi += i0*i2 + i1*i3;
+        }
+        sumf += d0*d1*sumi;
+    }
+    *s = sumf;
+#endif
+}
+
+static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_1;
+
+    assert(n % QK8_1 == 0);
+    assert(nb % 2 == 0);
+
+    const block_q4_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
+    // TODO: add AVX / WASM SIMD / etc
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float summs = 0;
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q8_1 * restrict y0 = &y[i + 0];
+        const block_q8_1 * restrict y1 = &y[i + 1];
+
+        summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float * d0 = &x[i].d;
+        const float * d1 = &y[i].d;
+
+        summs += x[i].m * (y[i].s0 + y[i].s1);
+
+        const __m256 d0v = _mm256_broadcast_ss( d0 );
+        const __m256 d1v = _mm256_broadcast_ss( d1 );
+
+        // Compute combined scales
+        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        const __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
+
+        const __m256 xy = mul_sum_i8_pairs_float(bx, by);
+
+        // Accumulate d0*d1*x*y
+        acc = _mm256_fmadd_ps( d0d1, xy, acc );
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#else
+    // scalar
+    float sumf = 0.0;
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float m0 = x[i].m;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
+
+        // TODO: this is very slow ..
+        for (int j = 0; j < QK8_1/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const float f0 = d0*(v0 & 0x0F) + m0;
+            const float f1 = d0*(v0 >>   4) + m0;
 
-    // reduce sum0..sum3 to sum0
-    GGML_F16_VEC_REDUCE(sumf, sum);
+            const float f2 = d1*p1[2*j + 0];
+            const float f3 = d1*p1[2*j + 1];
 
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
-    }
-#else
-    for (int i = 0; i < n; ++i) {
-        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
+            sumf += f0*f2 + f1*f3;
+        }
     }
-#endif
-
     *s = sumf;
+#endif
 }
 
-static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK;
+static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
 
-    assert(n % QK == 0);
+    assert(n % QK8_0 == 0);
     assert(nb % 2 == 0);
+    assert(QK8_0 == 2*QK4_2);
 
-    const block_q4_0 * restrict x = vx;
-    const block_q4_0 * restrict y = vy;
-
-    float sumf = 0.0;
+    const block_q4_2 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
 
 #if defined(__ARM_NEON)
-    float sum0 = 0.0f;
-    float sum1 = 0.0f;
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict y0 = &y[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q4_0 * restrict y1 = &y[i + 1];
+        const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
+        const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
+        const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
+        const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
 
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-        const int8x16_t  s8b = vdupq_n_s8(0x8);
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
 
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
+        const uint8x16_t m4b   = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b   = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
+        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
 
         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
-        const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
-
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
-        const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-        const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
 
         // sub 8
         const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
         const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
-
         const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int32x4_t
-        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
-        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
-
-        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
-        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
 
-        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
 
-        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
-        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
-        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
-
-        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
-        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
 
-        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
+        sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
+
+        sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
 #endif
     }
 
-    sumf = sum0 + sum1;
-#elif defined(__AVX512F__)
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
     // Initialize accumulator with zeros
-    __m512 acc0 = _mm512_setzero_ps();
-    __m512 acc1 = _mm512_setzero_ps();
+    __m256 acc = _mm256_setzero_ps();
 
-    const int superblock_size = 8;
-    const int superblock_count = nb / superblock_size;
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
+        const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
+        const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
 
-    for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
-        int i = superblock_ix * superblock_size;
+        __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
+        __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
+        __m256i bx = _mm256_set_m128i(bx1, bx0);
 
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
-    }
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8(8);
+        bx = _mm256_sub_epi8(bx, off);
 
-    // Remainders
-    for (int i = superblock_count * superblock_size; i < nb; ++i) {
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
-    }
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-    // Horizontal sum of all lanes of the accumulator
-    sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
-#elif defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
-    /* Prepare the constants we will need during execution */
-    const __m256i lowMask = _mm256_set1_epi8( 0xF );
-    const __m256i offset_8 = _mm256_set1_epi16( 8 );
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
 
-#define UNROLL_COUNT 8
-    // make sure we only unroll multiples of the block count
-    assert(nb % UNROLL_COUNT == 0);
+    *s = hsum_float_8(acc);
+#else
+    // scalar
+    float sumf = 0.0;
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[2*i + 0].qs;
+        const uint8_t * restrict x1 = x[2*i + 1].qs;
+        const  int8_t * restrict y0 = y[i].qs;
 
-    // Main loop
-    for (int i = 0; i < nb; i+=UNROLL_COUNT) {
-        // This loop will be unrolled by the compiler
-        for (int u=0;u<UNROLL_COUNT;u++)  {
-            /* Compute combined scale for the block */
-            const __m256 scale = _mm256_mul_ps(
-                    _mm256_broadcast_ss( &x[i+u].d ),
-                    _mm256_broadcast_ss( &y[i+u].d ) );
-
-            /* get input from x
-               Input: 32 Nibbles (16 bytes) at *x[i+u]
-               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
-
-            /* Load 16 bytes from memory */
-            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
-            /* Unpack values into individual bytes */
-            __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
-            const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
-            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
-            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval.  Offset them into [ -8 .. +7 ] interval.  */
-            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
-            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
-
-            /* get input from y
-               Input: 32 Nibbles (16 bytes) at *y[i+u]
-               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
-
-            /* Load 16 bytes from memory */
-            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
-            /* Unpack values into individual bytes */
-            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
-            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
-            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
-            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval.  Offset them into [ -8 .. +7 ] interval.  */
-            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
-            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
-
-            /* Compute products of int16_t integers, add pairwise, store as int32_t */
-            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
-            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
-
-            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
-            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
-
-            /* Convert to vectore of 8 int32_t to 8 floats */
-            __m256 q = _mm256_cvtepi32_ps( xy_q );
-
-            /* Multiply q with scale and accumulate */
-            acc = _mm256_fmadd_ps( scale, q, acc );
-        }
-    }
-
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
-#elif defined(__AVX__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
+        const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
+        const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
 
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        // Compute combined scale for the block
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        int sumi_0 = 0;
+        int sumi_1 = 0;
 
-        __m128i i32[2];
-        for (int j = 0; j < 2; ++j) {
-            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
-            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
-            __m128i by = bytesFromNibbles( y[i].qs + 8*j );
+        for (int j = 0; j < QK8_0/4; j++) {
+            const uint8_t v0 = x0[j];
+            const uint8_t v1 = x1[j];
 
-            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-            const __m128i off = _mm_set1_epi8( 8 );
-            bx = _mm_sub_epi8( bx, off );
-            by = _mm_sub_epi8( by, off );
+            const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
+            const int i1_0 = (int8_t) (v0 >>   4) - 8;
 
-            // Get absolute values of x vectors
-            const __m128i ax = _mm_sign_epi8(bx, bx);
+            const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
+            const int i1_1 = (int8_t) (v1 >>   4) - 8;
 
-            // Sign the values of the y vectors
-            const __m128i sy = _mm_sign_epi8(by, bx);
+            const int i2_0 = y0[2*j + 0];
+            const int i3_0 = y0[2*j + 1];
 
-            // Perform multiplication and create 16-bit values
-            const __m128i dot = _mm_maddubs_epi16(ax, sy);
+            const int i2_1 = y0[2*(j + QK8_0/4) + 0];
+            const int i3_1 = y0[2*(j + QK8_0/4) + 1];
 
-            const __m128i ones = _mm_set1_epi16(1);
-            i32[j] = _mm_madd_epi16(ones, dot);
+            sumi_0 += i0_0*i2_0 + i1_0*i3_0;
+            sumi_1 += i0_1*i2_1 + i1_1*i3_1;
         }
 
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
-        // Apply the scale, and accumulate
-        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+        sumf += (d0 * y[i].d) * sumi_0;
+        sumf += (d1 * y[i].d) * sumi_1;
     }
+    *s = sumf;
+#endif
+}
 
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
 
-    sumf = _mm_cvtss_f32( res );
-#elif defined(__wasm_simd128__)
-    // wasm simd
-    float sum0 = 0.0f;
-    float sum1 = 0.0f;
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_0 == QK5_0);
 
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict y0 = &y[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q4_0 * restrict y1 = &y[i + 1];
+    const block_q5_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
 
-        const v128_t m4b = wasm_u8x16_splat(0xf);
-        const v128_t s8b = wasm_i8x16_splat(0x8);
+#if defined(__ARM_NEON)
+    float32x4_t sumv = vdupq_n_f32(0.0f);
 
-        const v128_t v0_0 = wasm_v128_load(x0->qs);
-        const v128_t v0_1 = wasm_v128_load(y0->qs);
-        const v128_t v1_0 = wasm_v128_load(x1->qs);
-        const v128_t v1_1 = wasm_v128_load(y1->qs);
+    uint64_t tmp[4];
 
-        // 4-bit -> 8-bit
-        const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
-        const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q8_0 * restrict y0 = &y[i];
 
-        const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
-        const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
+        const uint8x16_t m4b  = vdupq_n_u8(0x0F);
+        const int8x16_t  s16b = vdupq_n_s8(0x10);
 
-        const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
-        const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
+        // extract the 5th bit
+        uint32_t qh;
+        memcpy(&qh, x0->qh, sizeof(qh));
 
-        const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
-        const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
+        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_u[(qh >> 24)       ];
 
-        // sub 8
-        const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
-        const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
+        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
+        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
+
+        const uint8x16_t v0 = vld1q_u8(x0->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8  (v0, m4b));
+        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
+
+        // interleave
+        const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
+        const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
 
-        const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
-        const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
+        // add high bit and sub 16
+        const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
+        const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
 
-        const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
-        const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
+        // load y
+        const int8x16_t v1l = vld1q_s8(y0->qs);
+        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
 
-        const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
-        const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
+        const float x0d = GGML_FP16_TO_FP32(x0->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
+                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
 
-        // dot product into int16x8_t
-        const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
-        const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
 
-        const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
-        const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+#endif
+    }
 
-        const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
-        const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
+    *s = vaddvq_f32(sumv);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
 
-        const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
-        const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
 
-        const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
-        const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
+        bx = _mm256_or_si256(bx, bxhi);
 
-        const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
-        const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
-        const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
-        sum0 += x0->d * y0->d * (
-                wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
-                wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
-                wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
-                wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
-        sum1 += x1->d * y1->d * (
-                wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
-                wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
-                wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
-                wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
     }
 
-    sumf = sum0 + sum1;
+    *s = hsum_float_8(acc);
 #else
     // scalar
+    float sumf = 0.0;
     for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float d1 = y[i].d;
+        const uint8_t * restrict x0 = x[i].qs;
+        const  int8_t * restrict y0 = y[i].qs;
 
-        const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
 
-        int sumi = 0;
-        for (int j = 0; j < QK/2; j++) {
-            const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
+        const float d = GGML_FP16_TO_FP32(x[i].d);
 
-            const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
-            const int8_t i1 = (int8_t) (v0 >> 4)  - 8;
+        int sxy = 0;
 
-            const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
-            const int8_t i3 = (int8_t) (v1 >> 4)  - 8;
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = x0[j];
 
-            sumi += i0*i2 + i1*i3;
+            const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
+            const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
+
+            const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
+            const int x1_0 = ((v0 >>   4) | x1_0h) - 16;
+
+            const int y0_0 = y0[2*j + 0];
+            const int y1_0 = y0[2*j + 1];
+
+            sxy += x0_0*y0_0 + x1_0*y1_0;
         }
-        sumf += d0 * d1 * sumi;
-    }
-#endif
 
+        sumf += (d*sxy)*y[i].d;
+    }
     *s = sumf;
+#endif
 }
 
-static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK;
+static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_1;
 
-    const block_q4_1 * restrict x = vx;
-    const block_q4_1 * restrict y = vy;
+    assert(n % QK8_1 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_1 == QK5_1);
 
-    float sumf = 0.0;
+    const block_q5_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
 
-#if defined(__AVX2__)
+#if defined(__ARM_NEON)
+    float32x4_t sumv = vdupq_n_f32(0.0f);
+
+    float summs = 0.0f;
+
+    uint64_t tmp[4];
+
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_1 * restrict x0 = &x[i];
+        const block_q8_1 * restrict y0 = &y[i];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
+
+        // extract the 5th bit
+        uint32_t qh;
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_u[(qh >> 24)       ];
+
+        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
+        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
+
+        const uint8x16_t v0 = vld1q_u8(x0->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8  (v0, vdupq_n_u8(0x0F)));
+        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
+
+        // interleave
+        const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
+        const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
+
+        // add
+        const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
+        const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
+
+        // load y
+        const int8x16_t v1l = vld1q_s8(y0->qs);
+        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
+
+        const float x0d = GGML_FP16_TO_FP32(x0->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
+                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv) + summs;
+#elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
-    // Accumulator for constant offsets
-    float acc_offset = 0.0f;
+    float summs = 0.0f;
 
     // Main loop
-    for (int i = 0; i < nb; ++i) {
-        const float * d0 = &x[i].d;
-        const float * d1 = &y[i].d;
-
-        const float * m0 = &x[i].m;
-        const float * m1 = &y[i].m;
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
 
-        const __m256 d0v = _mm256_broadcast_ss( d0 );
-        const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
-        const __m256 m1v = _mm256_broadcast_ss( m1 );
+        summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
 
-        // Compute combined scale for the block
-        const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
+        bx = _mm256_or_si256(bx, bxhi);
 
-        // Compute cross scales for the block
-        const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
-        const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
-        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
+        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        __m256i bx = bytesFromNibbles( x[i].qs );
-        __m256i by = bytesFromNibbles( y[i].qs );
-
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval.
-
-        // Sign-extend first 16 signed bytes into int16_t
-        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
-        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        // Compute products of int16_t integers, add pairwise
-        __m256i i32 = _mm256_madd_epi16( x16, y16 );
-
-        // Sign-extend last 16 signed bytes into int16_t vectors
-        __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
-        __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        // Accumulate products of int16_t integers
-        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
-
-        // compute sums of unsigned bytes in bx, by in blocks of 8.
-        // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
-        // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
-        // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
-        __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
-        __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
-        __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
-        __m256  sums  = _mm256_cvtepi32_ps( sumsi );
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( i32 );
-        // Apply the scale, and accumulate
-        // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
-        acc = _mm256_fmadd_ps( scale_01, p, acc );
-        acc = _mm256_fmadd_ps( cross_scales, sums, acc );
-        // acc_offset += m0*m1 (for each entry in the block)
-        acc_offset += (*m0)*(*m1);
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
     }
 
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+    *s = hsum_float_8(acc) + summs;
+#else
+    float sumf = 0.0;
 
-    sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
-#elif defined(__ARM_NEON)
-    float sum00 = 0.0f;
-    float sum01 = 0.0f;
-    float sum10 = 0.0f;
-    float sum11 = 0.0f;
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[i].qs;
+        const  int8_t * restrict y0 = y[i].qs;
 
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_1 * restrict x0 = &x[i + 0];
-        const block_q4_1 * restrict y0 = &y[i + 0];
-        const block_q4_1 * restrict x1 = &x[i + 1];
-        const block_q4_1 * restrict y1 = &y[i + 1];
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
 
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
 
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
+        int sxy = 0;
 
-        // 4-bit -> 8-bit
-        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
-        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
-        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+        for (int j = 0; j < QK8_1/2; j++) {
+            const uint8_t v0 = x0[j];
 
-        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
-        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
-        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
-        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+            const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
+            const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
 
-        sum00 += x0->m*y0->m;
-        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
-        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+            const int x0_0 = (v0 & 0x0F) | x0_0h;
+            const int x1_0 = (v0 >>   4) | x1_0h;
 
-        sum00 += x1->m*y1->m;
-        sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
-        sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+            const int y0_0 = y0[2*j + 0];
+            const int y1_0 = y0[2*j + 1];
 
-#if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int32x4_t
-        uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
-        uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
+            sxy += x0_0*y0_0 + x1_0*y1_0;
+        }
 
-        p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
-        p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
+        sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
+    }
 
-        sum11 += x0->d*y0->d*vaddvq_u32(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u32(p_1);
-#else
-        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
-        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
-        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
+    *s = sumf;
+#endif
+}
 
-        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
-        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
-        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
-        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
+static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
 
-        const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
-        const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_0 == QK8_0);
+
+    const block_q8_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-        const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
-        const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
+    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const int8x16_t x0_0 = vld1q_s8(x0->qs);
+        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
+        const int8x16_t x1_0 = vld1q_s8(x1->qs);
+        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
+
+        // load y
+        const int8x16_t y0_0 = vld1q_s8(y0->qs);
+        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
+        const int8x16_t y1_0 = vld1q_s8(y1->qs);
+        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
 
-        const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
-        const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
 
-        sum11 += x0->d*y0->d*vaddvq_u16(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u16(p_1);
+#else
+        const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
+        const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
+        const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
+        const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
+
+        const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
+        const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
+        const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
+        const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
+
+        const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
+        const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+        const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
+        const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
 #endif
     }
 
-    sumf = QK*sum00 + sum01 + sum10 + sum11;
-#else
-    // scalar
-    for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float d1 = y[i].d;
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
 
-        const float m0 = x[i].m;
-        const float m1 = y[i].m;
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
-        for (int j = 0; j < QK/2; j++) {
-            const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
+        // Multiply q with scale and accumulate
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
 
-            const float f0 = d0*(v0 & 0xf) + m0;
-            const float f1 = d0*(v0 >> 4)  + m0;
+    *s = hsum_float_8(acc);
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        const int8_t * restrict x0 = x[i].qs;
+        const int8_t * restrict y0 = y[i].qs;
+
+        int sumi = 0;
 
-            const float f2 = d1*(v1 & 0xf) + m1;
-            const float f3 = d1*(v1 >> 4)  + m1;
+        for (int j = 0; j < QK8_0; j++) {
+            const int v0 = x0[j];
+            const int v1 = y0[j];
 
-            sumf += f0*f2 + f1*f3;
+            sumi += v0*v1;
         }
+
+        sumf += (x[i].d*y[i].d)*sumi;
     }
-#endif
 
     *s = sumf;
+#endif
 }
 
 // compute GGML_VEC_DOT_UNROLL dot products at once
@@ -2613,6 +3646,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #endif
 }
 
+inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += (ggml_float)x[i];
+    }
+    *s = sum;
+}
+
 inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     float max = -INFINITY;
@@ -2661,24 +3702,67 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = 1,
     [GGML_TYPE_F16]  = 1,
-    [GGML_TYPE_Q4_0] = QK,
-    [GGML_TYPE_Q4_1] = QK,
+    [GGML_TYPE_Q4_0] = QK4_0,
+    [GGML_TYPE_Q4_1] = QK4_1,
+    [GGML_TYPE_Q4_2] = QK4_2,
+    [GGML_TYPE_Q5_0] = QK5_0,
+    [GGML_TYPE_Q5_1] = QK5_1,
+    [GGML_TYPE_Q8_0] = QK8_0,
+    [GGML_TYPE_Q8_1] = QK8_1,
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = sizeof(float),
     [GGML_TYPE_F16]  = sizeof(ggml_fp16_t),
     [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
     [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+    [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
+    [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
+    [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
+    [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
+    [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
     [GGML_TYPE_I8]   = sizeof(int8_t),
     [GGML_TYPE_I16]  = sizeof(int16_t),
     [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
+
+
+static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_F32]  = "f32",
+    [GGML_TYPE_F16]  = "f16",
+    [GGML_TYPE_Q4_0] = "q4_0",
+    [GGML_TYPE_Q4_1] = "q4_1",
+    [GGML_TYPE_Q4_2] = "q4_2",
+    [GGML_TYPE_Q5_0] = "q5_0",
+    [GGML_TYPE_Q5_1] = "q5_1",
+    [GGML_TYPE_Q8_0] = "q8_0",
+    [GGML_TYPE_Q8_1] = "q8_1",
+    [GGML_TYPE_I8]   = "i8",
+    [GGML_TYPE_I16]  = "i16",
+    [GGML_TYPE_I32]  = "i32",
+};
+static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
+
+static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_F32]  = false,
+    [GGML_TYPE_F16]  = false,
+    [GGML_TYPE_Q4_0] = true,
+    [GGML_TYPE_Q4_1] = true,
+    [GGML_TYPE_Q4_2] = true,
+    [GGML_TYPE_Q5_0] = true,
+    [GGML_TYPE_Q5_1] = true,
+    [GGML_TYPE_Q8_0] = true,
+    [GGML_TYPE_Q8_1] = true,
+    [GGML_TYPE_I8]   = false,
+    [GGML_TYPE_I16]  = false,
+    [GGML_TYPE_I32]  = false,
+};
+static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -2726,7 +3810,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2774,7 +3858,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -2904,6 +3988,11 @@ float ggml_type_sizef(enum ggml_type type) {
     return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
 }
 
+const char * ggml_type_name(enum ggml_type type) {
+    return GGML_TYPE_NAME[type];
+}
+
+
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return GGML_TYPE_SIZE[tensor->type];
 }
@@ -2935,6 +4024,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
+bool ggml_is_quantized(enum ggml_type type) {
+    return GGML_IS_QUANTIZED[type];
+}
+
 static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1];
 }
@@ -3045,6 +4138,13 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
+        // initialize cuBLAS
+        #if defined(GGML_USE_CUBLAS)
+        ggml_init_cublas();
+        #elif defined(GGML_USE_CLBLAST)
+        ggml_cl_init();
+        #endif
+
         is_first_call = false;
     }
 
@@ -3346,14 +4446,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
     char * const data = tensor->data;
 
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));
@@ -3389,7 +4481,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
                     ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
                 }
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -3401,19 +4493,11 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
 struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
     const int n     = ggml_nrows(tensor);
     const int nc    = tensor->ne[0];
-    const size_t n1 = tensor->nb[1];
-
-    char * const data = tensor->data;
-
-    switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
+    const size_t n1 = tensor->nb[1];
+
+    char * const data = tensor->data;
+
+    switch (tensor->type) {
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));
@@ -3449,7 +4533,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
                     ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
                 }
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -3460,14 +4544,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
 
 int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3493,7 +4569,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -3504,14 +4580,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
 
 void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3537,7 +4605,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 ((float *)(tensor->data))[i] = value;
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -3546,14 +4614,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
 
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3579,7 +4639,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -3590,14 +4650,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
 
 void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3623,7 +4675,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 ((float *)(tensor->data))[i] = value;
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -4804,6 +5856,37 @@ struct ggml_tensor * ggml_rope(
     return result;
 }
 
+// ggml_alibi
+
+struct ggml_tensor * ggml_alibi(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_head) {
+    GGML_ASSERT(n_past >= 0);
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = n_head;
+
+    result->op   = GGML_OP_ALIBI;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_conv_1d_1s
 
 struct ggml_tensor * ggml_conv_1d_1s(
@@ -5023,7 +6106,6 @@ static void ggml_compute_forward_dup_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(params->ith == 0);
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5035,6 +6117,11 @@ static void ggml_compute_forward_dup_f16(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
@@ -5045,19 +6132,40 @@ static void ggml_compute_forward_dup_f16(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+        // parallelize by elements
+        const int ne = ggml_nelements(dst);
+        const int dr = (ne + nth - 1) / nth;
+        const int ie0 = dr * ith;
+        const int ie1 = MIN(ie0 + dr, ne);
+
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb00),
+            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+
         return;
     }
 
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
     if (src0->type == dst->type &&
-        src0->ne[0] == dst->ne[0] &&
-        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        ne00 == ne0 &&
+        nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
         // copy by rows
         const size_t rs = ne00*nb00;
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     memcpy(
                         ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                         ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5071,21 +6179,21 @@ static void ggml_compute_forward_dup_f16(
     // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
 
     if (ggml_is_contiguous(dst)) {
-        if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+        if (nb00 == sizeof(ggml_fp16_t)) {
             if (dst->type == GGML_TYPE_F16) {
                 size_t id = 0;
-                const size_t rs = ne00*nb00;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            char * dst_ptr = (char *) dst->data + id*rs;
-
-                            memcpy(dst_ptr, src0_ptr, rs);
-
-                            id++;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
                         }
+                        id += rs * (ne01 - ir1);
                     }
                 }
             } else if (dst->type == GGML_TYPE_F32) {
@@ -5094,14 +6202,39 @@ static void ggml_compute_forward_dup_f16(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                             for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_is_quantized(dst->type)) {
+                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                            }
+
+                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
                     }
                 }
             } else {
@@ -5116,7 +6249,8 @@ static void ggml_compute_forward_dup_f16(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             for (int i00 = 0; i00 < ne00; i00++) {
                                 const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5124,6 +6258,7 @@ static void ggml_compute_forward_dup_f16(
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
                     }
                 }
             } else if (dst->type == GGML_TYPE_F16) {
@@ -5132,7 +6267,8 @@ static void ggml_compute_forward_dup_f16(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             for (int i00 = 0; i00 < ne00; i00++) {
                                 const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5140,6 +6276,7 @@ static void ggml_compute_forward_dup_f16(
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
                     }
                 }
             } else {
@@ -5158,7 +6295,20 @@ static void ggml_compute_forward_dup_f16(
     if (dst->type == GGML_TYPE_F16) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                               char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
@@ -5179,25 +6329,51 @@ static void ggml_compute_forward_dup_f16(
                         }
                     }
                 }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
             }
         }
     } else if (dst->type == GGML_TYPE_F32) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                               char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
 
                         *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
 
-                        if (++i10 == ne00) {
+                        if (++i10 == ne0) {
                             i10 = 0;
-                            if (++i11 == ne01) {
+                            if (++i11 == ne1) {
                                 i11 = 0;
-                                if (++i12 == ne02) {
+                                if (++i12 == ne2) {
                                     i12 = 0;
-                                    if (++i13 == ne03) {
+                                    if (++i13 == ne3) {
                                         i13 = 0;
                                     }
                                 }
@@ -5205,6 +6381,19 @@ static void ggml_compute_forward_dup_f16(
                         }
                     }
                 }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
             }
         }
     } else {
@@ -5216,7 +6405,6 @@ static void ggml_compute_forward_dup_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(params->ith == 0);
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5228,6 +6416,11 @@ static void ggml_compute_forward_dup_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
@@ -5238,19 +6431,40 @@ static void ggml_compute_forward_dup_f32(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+        // parallelize by elements
+        const int ne = ggml_nelements(dst);
+        const int dr = (ne + nth - 1) / nth;
+        const int ie0 = dr * ith;
+        const int ie1 = MIN(ie0 + dr, ne);
+
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb00),
+            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+
         return;
     }
 
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
     if (src0->type == dst->type &&
-        src0->ne[0] == dst->ne[0] &&
-        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        ne00 == ne0 &&
+        nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
         // copy by rows
         const size_t rs = ne00*nb00;
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     memcpy(
                         ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                         ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5263,21 +6477,21 @@ static void ggml_compute_forward_dup_f32(
 
     if (ggml_is_contiguous(dst)) {
         // TODO: simplify
-        if (src0->nb[0] == sizeof(float)) {
+        if (nb00 == sizeof(float)) {
             if (dst->type == GGML_TYPE_F32) {
                 size_t id = 0;
-                const size_t rs = ne00*nb00;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            char * dst_ptr = (char *) dst->data + id*rs;
-
-                            memcpy(dst_ptr, src0_ptr, rs);
-
-                            id++;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
                         }
+                        id += rs * (ne01 - ir1);
                     }
                 }
             } else if (dst->type == GGML_TYPE_F16) {
@@ -5286,7 +6500,8 @@ static void ggml_compute_forward_dup_f32(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             for (int i00 = 0; i00 < ne00; i00++) {
                                 const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5294,6 +6509,25 @@ static void ggml_compute_forward_dup_f32(
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_is_quantized(dst->type)) {
+                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
                     }
                 }
             } else {
@@ -5308,7 +6542,8 @@ static void ggml_compute_forward_dup_f32(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             for (int i00 = 0; i00 < ne00; i00++) {
                                 const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5316,6 +6551,7 @@ static void ggml_compute_forward_dup_f32(
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
                     }
                 }
             } else if (dst->type == GGML_TYPE_F16) {
@@ -5324,7 +6560,8 @@ static void ggml_compute_forward_dup_f32(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             for (int i00 = 0; i00 < ne00; i00++) {
                                 const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5332,6 +6569,7 @@ static void ggml_compute_forward_dup_f32(
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
                     }
                 }
             } else {
@@ -5343,6 +6581,7 @@ static void ggml_compute_forward_dup_f32(
     }
 
     // dst counters
+
     int64_t i10 = 0;
     int64_t i11 = 0;
     int64_t i12 = 0;
@@ -5351,20 +6590,33 @@ static void ggml_compute_forward_dup_f32(
     if (dst->type == GGML_TYPE_F32) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                               char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
 
                         memcpy(dst_ptr, src0_ptr, sizeof(float));
 
-                        if (++i10 == dst->ne[0]) {
+                        if (++i10 == ne0) {
                             i10 = 0;
-                            if (++i11 == dst->ne[1]) {
+                            if (++i11 == ne1) {
                                 i11 = 0;
-                                if (++i12 == dst->ne[2]) {
+                                if (++i12 == ne2) {
                                     i12 = 0;
-                                    if (++i13 == dst->ne[3]) {
+                                    if (++i13 == ne3) {
                                         i13 = 0;
                                     }
                                 }
@@ -5372,25 +6624,51 @@ static void ggml_compute_forward_dup_f32(
                         }
                     }
                 }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
             }
         }
     } else if (dst->type == GGML_TYPE_F16) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                               char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
 
                         *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
 
-                        if (++i10 == dst->ne[0]) {
+                        if (++i10 == ne0) {
                             i10 = 0;
-                            if (++i11 == dst->ne[1]) {
+                            if (++i11 == ne1) {
                                 i11 = 0;
-                                if (++i12 == dst->ne[2]) {
+                                if (++i12 == ne2) {
                                     i12 = 0;
-                                    if (++i13 == dst->ne[3]) {
+                                    if (++i13 == ne3) {
                                         i13 = 0;
                                     }
                                 }
@@ -5398,6 +6676,19 @@ static void ggml_compute_forward_dup_f32(
                         }
                     }
                 }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
             }
         }
     } else {
@@ -5418,21 +6709,122 @@ static void ggml_compute_forward_dup(
             {
                 ggml_compute_forward_dup_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
     }
 }
 
-// ggml_compute_forward_add
-
-static void ggml_compute_forward_add_f32(
+// ggml_compute_forward_add
+
+static void ggml_compute_forward_add_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (nb10 == sizeof(float)) {
+        for (int j = ith; j < n; j += nth) {
+#ifdef GGML_USE_ACCELERATE
+            vDSP_vadd(
+                    (float *) ((char *) src0->data + j*nb01), 1,
+                    (float *) ((char *) src1->data + j*nb11), 1,
+                    (float *) ((char *) dst->data  + j*nb1),  1, nc);
+#else
+            ggml_vec_add_f32(nc,
+                    (float *) ((char *) dst->data  + j*nb1),
+                    (float *) ((char *) src0->data + j*nb01),
+                    (float *) ((char *) src1->data + j*nb11));
+#endif
+        }
+    } else {
+        // src1 is not contiguous
+        for (int j = ith; j < n; j += nth) {
+            float * dst_ptr  = (float *) ((char *) dst->data  + j*nb1);
+            float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+
+                dst_ptr[i] = src0_ptr[i] + *src1_ptr;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_add_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(float)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
+}
+
+static void ggml_compute_forward_add_f16_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5458,35 +6850,135 @@ static void ggml_compute_forward_add_f32(
     const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
 
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
 
-    if (nb10 == sizeof(float)) {
-        for (int j = ith; j < n; j += nth) {
-#ifdef GGML_USE_ACCELERATE
-            vDSP_vadd(
-                    (float *) ((char *) src0->data + j*nb01), 1,
-                    (float *) ((char *) src1->data + j*nb11), 1,
-                    (float *) ((char *) dst->data  + j*nb1),  1, nc);
-#else
-            ggml_vec_add_f32(nc,
-                    (float *) ((char *) dst->data  + j*nb1),
-                    (float *) ((char *) src0->data + j*nb01),
-                    (float *) ((char *) src1->data + j*nb11));
-#endif
-        }
-    } else {
-        // src1 is not contiguous
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
         for (int j = ith; j < n; j += nth) {
-            float * dst_ptr  = (float *) ((char *) dst->data  + j*nb1);
-            float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
             for (int i = 0; i < nc; i++) {
-                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-
-                dst_ptr[i] = src0_ptr[i] + *src1_ptr;
+                ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
             }
         }
     }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    //const int64_t ne0  = dst->ne[0];
+    //const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type));
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb0));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, wdata, src1_row);
+        // quantize row to dst
+        quantize_row_q(wdata, dst_row, ne00);
+    }
 }
 
 static void ggml_compute_forward_add(
@@ -5499,13 +6991,28 @@ static void ggml_compute_forward_add(
             {
                 ggml_compute_forward_add_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
+            } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5551,13 +7058,7 @@ static void ggml_compute_forward_sub(
             {
                 ggml_compute_forward_sub_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5603,13 +7104,7 @@ static void ggml_compute_forward_mul(
             {
                 ggml_compute_forward_mul_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5655,13 +7150,7 @@ static void ggml_compute_forward_div(
             {
                 ggml_compute_forward_div_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5703,13 +7192,7 @@ static void ggml_compute_forward_sqr(
             {
                 ggml_compute_forward_sqr_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5751,13 +7234,7 @@ static void ggml_compute_forward_sqrt(
             {
                 ggml_compute_forward_sqrt_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5789,15 +7266,20 @@ static void ggml_compute_forward_sum_f32(
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
+    ggml_float sum     = 0;
+    ggml_float row_sum = 0;
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = 0; i01 < ne01; i01++) {
-                ggml_vec_sum_f32(ne00,
-                        (float *) (dst->data),
+                ggml_vec_sum_ggf(ne00,
+                        &row_sum,
                         (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+                sum += row_sum;
             }
         }
     }
+    ((float *) dst->data)[0] = sum;
 }
 
 static void ggml_compute_forward_sum(
@@ -5809,13 +7291,7 @@ static void ggml_compute_forward_sum(
             {
                 ggml_compute_forward_sum_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5886,13 +7362,7 @@ static void ggml_compute_forward_mean(
             {
                 ggml_compute_forward_mean_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5950,13 +7420,7 @@ static void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5998,13 +7462,7 @@ static void ggml_compute_forward_abs(
             {
                 ggml_compute_forward_abs_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6046,13 +7504,7 @@ static void ggml_compute_forward_sgn(
             {
                 ggml_compute_forward_sgn_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6094,13 +7546,7 @@ static void ggml_compute_forward_neg(
             {
                 ggml_compute_forward_neg_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6142,13 +7588,7 @@ static void ggml_compute_forward_step(
             {
                 ggml_compute_forward_step_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6190,13 +7630,7 @@ static void ggml_compute_forward_relu(
             {
                 ggml_compute_forward_relu_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6255,13 +7689,7 @@ static void ggml_compute_forward_gelu(
             {
                 ggml_compute_forward_gelu_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6317,18 +7745,12 @@ static void ggml_compute_forward_silu(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_silu_f32(params, src0, dst);
-            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_f32(params, src0, dst);
+            } break;
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6408,13 +7830,7 @@ static void ggml_compute_forward_norm(
             {
                 ggml_compute_forward_norm_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6488,13 +7904,7 @@ static void ggml_compute_forward_rms_norm(
             {
                 ggml_compute_forward_rms_norm_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6504,7 +7914,7 @@ static void ggml_compute_forward_rms_norm(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
@@ -6520,8 +7930,12 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int64_t ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
+    if (
+#if !defined(GGML_USE_CUBLAS)
+        ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) &&
+#endif
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
 
         /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
         return true;
@@ -6529,6 +7943,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 
     return false;
 }
+
 #endif
 
 static void ggml_compute_forward_mul_mat_f32(
@@ -6544,7 +7959,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -6601,7 +8016,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -6615,22 +8030,65 @@ static void ggml_compute_forward_mul_mat_f32(
             return;
         }
 
+#if defined(GGML_USE_CUBLAS)
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne00;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        size_t x_size, y_size, d_size;
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if !defined(GGML_USE_CUBLAS)
                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-
+#endif
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
+#elif defined(GGML_USE_CLBLAST)
                 // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        GGML_TYPE_F32);
+#else
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
-
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
+#endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
@@ -6760,7 +8218,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -6776,10 +8234,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
+#if defined(GGML_USE_CUBLAS)
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne00;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
 
+        size_t x_size, y_size, d_size;
+        ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float       * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+#else
+        float * const wdata = params->wdata;
+#endif
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if defined(GGML_USE_CUBLAS)
+                // copy src0 while converting src1
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
+
+                // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
+                ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
+                {
+                    size_t id = 0;
+                    for (int64_t i01 = 0; i01 < ne11; ++i01) {
+                        for (int64_t i00 = 0; i00 < ne10; ++i00) {
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
+                        }
+                    }
+                }
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -6788,7 +8273,41 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         }
                     }
                 }
+#endif
 
+#if defined(GGML_USE_CUBLAS)
+                const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, CUDA_R_16F, ne00,
+                                    d_Y, CUDA_R_16F, ne10,
+                            &beta,  d_D, CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
+                            CUBLAS_GEMM_DEFAULT));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
+#elif defined(GGML_USE_CLBLAST)
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        GGML_TYPE_F32);
+#else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
@@ -6800,9 +8319,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
+#endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
         return;
@@ -6886,27 +8412,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q         = dequantize_row_q4_0,
-        .quantize_row_q           = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .vec_dot_q                = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q         = dequantize_row_q4_1,
-        .quantize_row_q           = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .vec_dot_q                = ggml_vec_dot_q4_1,
-    },
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -6954,8 +8459,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
     GGML_ASSERT(ne3  == ne13);
 
     const enum ggml_type type = src0->type;
-    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
-    vec_dot_q_t      const vec_dot_q      = quantize_fns[type].vec_dot_q;
+    quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
+    vec_dot_q_t      const vec_dot_q          = quantize_fns[type].vec_dot_q;
+    enum ggml_type   const vec_dot_type       = quantize_fns[type].vec_dot_type;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -6975,7 +8481,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -6989,11 +8495,42 @@ static void ggml_compute_forward_mul_mat_q_f32(
             return;
         }
 
+#if defined(GGML_USE_CUBLAS)
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne00;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        size_t x_size, y_size, d_size, q_size;
+        float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        void  * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
+
+        const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
+        GGML_ASSERT(dequantize_row_q_cuda != NULL);
+#else
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+#endif
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+#if defined(GGML_USE_CUBLAS)
+                // copy and dequantize on device
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2));
+
+                dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
+                CUDA_CHECK(cudaGetLastError());
+                CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
+#elif defined(GGML_USE_CLBLAST)
+                const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7001,21 +8538,51 @@ static void ggml_compute_forward_mul_mat_q_f32(
                         id += ne00;
                     }
                 }
-
                 const float * x = wdata;
-                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+#endif
 
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
 
+                // wait for dequantization
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
+#elif defined(GGML_USE_CLBLAST)
                 // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        type);
+#else
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
+        ggml_cuda_pool_free(d_Q, q_size);
+#endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
@@ -7024,12 +8591,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
     if (params->type == GGML_TASK_INIT) {
         char * wdata = params->wdata;
-        const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+        const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                    quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
                     wdata += row_size;
                 }
             }
@@ -7055,7 +8622,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     void * wdata = params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+    const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
     for (int ir = ir0; ir < ir1; ++ir) {
         // src0 indices
@@ -7103,6 +8670,11 @@ static void ggml_compute_forward_mul_mat(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -7114,42 +8686,11 @@ static void ggml_compute_forward_mul_mat(
             {
                 ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
     }
-
-#if 0
-    if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
-        static int first = 8;
-        printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-        if (first) {
-            --first;
-        } else {
-            for (int k = 0; k < dst->ne[1]; ++k) {
-                for (int j = 0; j < dst->ne[0]/16; ++j) {
-                    for (int i = 0; i < 16; ++i) {
-                        printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
-                    }
-                    printf("\n");
-                }
-                printf("\n");
-            }
-            printf("\n");
-            exit(0);
-        }
-    } else {
-        printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("aaaa dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-    }
-#endif
 }
 
 // ggml_compute_forward_scale
@@ -7199,13 +8740,7 @@ static void ggml_compute_forward_scale(
             {
                 ggml_compute_forward_scale_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7366,6 +8901,11 @@ static void ggml_compute_forward_get_rows(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -7377,10 +8917,7 @@ static void ggml_compute_forward_get_rows(
             {
                 ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7453,13 +8990,7 @@ static void ggml_compute_forward_diag_mask_inf(
             {
                 ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7472,87 +9003,237 @@ static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *p = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(p[i]));
+        }
+#endif
+
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, p);
+
+        ggml_float sum = 0.0;
+
+        uint16_t scvt;
+        for (int i = 0; i < nc; i++) {
+            //printf("p[%3d] = %8.4f\n", i, p[i]);
+            if (p[i] == -INFINITY) {
+                p[i] = 0.0f;
+            } else {
+                //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
+                memcpy(&scvt, &s, sizeof(scvt));
+                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                sum += (ggml_float)val;
+                p[i] = val;
+            }
+        }
+
+        assert(sum > 0.0);
+
+        sum = 1.0/sum;
+        ggml_vec_scale_f32(nc, p, sum);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(p[i]));
+            assert(!isinf(p[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_alibi
+
+static void ggml_compute_forward_alibi_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 2);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_head = ((int32_t *) src1->data)[1];
+
+    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int ne1 = src0->ne[1]; // seq_len_without_past
+    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    //const int ne3 = src0->ne[3]; // 1 -> bsz
+
+    const int n  = ggml_nrows(src0);
+    const int ne2_ne3 = n/ne1; // ne2*ne3
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    //const int nb3 = src0->nb[3];
+
+    assert(nb0 == sizeof(float));
+    assert(ne1+n_past == ne0);
+
+    // add alibi to src0 (KQ_scaled)
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+
+    for (int i = 0; i < ne0; i++) {
+        for (int j = 0; j < ne1; j++) {
+            for (int k = 0; k < ne2_ne3; k++) {
+                float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                float *      pdst = (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
+
+                // TODO: k*nb2 or k*nb3
+
+                float m_k;
+
+                if (k < n_heads_log2_floor) {
+                    m_k = powf(m0, k + 1);
+                } else {
+                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+                }
+
+                pdst[0] = (j+1) * m_k + src[0];
+            }
+        }
+    }
+}
+
+
+static void ggml_compute_forward_alibi_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 2);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    // TODO: handle transposed/permuted matrices
-
-    const int ith = params->ith;
-    const int nth = params->nth;
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_head = ((int32_t *) src1->data)[1];
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
+    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int ne1 = src0->ne[1]; // seq_len_without_past
+    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    //const int ne3 = src0->ne[3]; // 1 -> bsz
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    const int n  = ggml_nrows(src0);
+    const int ne2_ne3 = n/ne1; // ne2*ne3
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    //const int nb3 = src0->nb[3];
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float *p = (float *)((char *) dst->data + i1*dst->nb[1]);
+    assert(nb0 == sizeof(ggml_fp16_t));
+    assert(ne1+n_past == ne0);
 
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(p[i]));
-        }
-#endif
+    // add alibi to src0 (KQ_scaled)
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
 
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, p);
+    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
 
-        ggml_float sum = 0.0;
+    for (int i = 0; i < ne0; i++) {
+        for (int j = 0; j < ne1; j++) {
+            for (int k = 0; k < ne2_ne3; k++) {
+                ggml_fp16_t * const src  = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                      float *      pdst  =       (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
 
-        uint16_t scvt;
-        for (int i = 0; i < nc; i++) {
-            if (p[i] == -INFINITY) {
-                p[i] = 0.0f;
-            } else {
-                //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
-                memcpy(&scvt, &s, sizeof(scvt));
-                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
-                sum += (ggml_float)val;
-                p[i] = val;
-            }
-        }
+                // TODO: k*nb2 or k*nb3
 
-        assert(sum > 0.0);
+                float m_k;
 
-        sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, p, sum);
+                if (k < n_heads_log2_floor) {
+                    m_k = powf(m0, k + 1);
+                } else {
+                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+                }
 
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(p[i]));
-            assert(!isinf(p[i]));
+                // we return F32
+                pdst[0] = (j+1) * m_k + GGML_FP16_TO_FP32(src[0]);
+            }
         }
-#endif
     }
 }
 
-static void ggml_compute_forward_soft_max(
+static void ggml_compute_forward_alibi(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_alibi_f16(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_f32(params, src0, dst);
+                ggml_compute_forward_alibi_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
                 GGML_ASSERT(false);
@@ -7610,9 +9291,11 @@ static void ggml_compute_forward_rope_f32(
 
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
+    const bool is_neox = mode & 2;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = (mode == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -7625,14 +9308,25 @@ static void ggml_compute_forward_rope_f32(
 
                     theta *= theta_scale;
 
-                    const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                          float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                    if (!is_neox) {
+                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[1];
 
-                    const float x0 = src[0];
-                    const float x1 = src[1];
+                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[1] = x0*sin_theta + x1*cos_theta;
+                    } else {
+                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
 
-                    dst_data[0] = x0*cos_theta - x1*sin_theta;
-                    dst_data[1] = x0*sin_theta + x1*cos_theta;
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];
+
+                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                    }
                 }
             }
         }
@@ -7687,9 +9381,11 @@ static void ggml_compute_forward_rope_f16(
 
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
+    const bool is_neox = mode & 2;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = (mode == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -7702,14 +9398,25 @@ static void ggml_compute_forward_rope_f16(
 
                     theta *= theta_scale;
 
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                          ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                    if (!is_neox) {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[1]);
 
-                    const float x0 = ggml_fp16_to_fp32(src[0]);
-                    const float x1 = ggml_fp16_to_fp32(src[1]);
+                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    } else {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
 
-                    dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
-                    dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
                 }
             }
         }
@@ -7730,12 +9437,7 @@ static void ggml_compute_forward_rope(
             {
                 ggml_compute_forward_rope_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7998,12 +9700,7 @@ static void ggml_compute_forward_conv_1d_1s(
             {
                 ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8266,12 +9963,7 @@ static void ggml_compute_forward_conv_1d_2s(
             {
                 ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8751,12 +10443,7 @@ static void ggml_compute_forward_flash_attn(
             {
                 ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8962,12 +10649,7 @@ static void ggml_compute_forward_flash_ff(
             {
                 GGML_ASSERT(false); // TODO
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -9011,13 +10693,7 @@ static void ggml_compute_forward_map_unary(
             {
                 ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -9066,13 +10742,7 @@ static void ggml_compute_forward_map_binary(
             {
                 ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -9209,6 +10879,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_ALIBI:
+            {
+                ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_CONV_1D_1S:
             {
                 ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
@@ -9411,6 +11085,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_ALIBI:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SILU:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -9822,13 +11500,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             struct ggml_tensor * node = cgraph->nodes[i];
 
             switch (node->op) {
+                case GGML_OP_CPY:
                 case GGML_OP_DUP:
                     {
-                        node->n_tasks = 1;
+                        node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+                        if (ggml_is_quantized(node->type)) {
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
+                        }
+
+                        work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_ADD:
                     {
                         node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        if (ggml_is_quantized(node->src0->type)) {
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
+                        }
+
+                        work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_SUB:
                 case GGML_OP_MUL:
@@ -9873,14 +11567,17 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         size_t cur = 0;
 
                         if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                                // here we need memory just for single 2D matrix from src0
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                                //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
-                                //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
-                                //printf("cur = %zu\n", cur);
+#else
+                                // with GPU, we need memory for the full 3D / 4D data
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0));
+#endif
                             } else {
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
                             }
@@ -9889,15 +11586,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
-                        } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                node->n_tasks = 1;
+                            }
+#endif
+                        } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
                             } else
 #endif
                             {
-                                cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
+                                const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
+                                cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
                             }
                         } else {
                             GGML_ASSERT(false);
@@ -9909,7 +11612,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = n_threads;
                     } break;
-                case GGML_OP_CPY:
                 case GGML_OP_CONT:
                 case GGML_OP_RESHAPE:
                 case GGML_OP_VIEW:
@@ -9928,6 +11630,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = n_threads;
                     } break;
+                case GGML_OP_ALIBI:
+                    {
+                        node->n_tasks = 1; //TODO
+                    } break;
                 case GGML_OP_CONV_1D_1S:
                 case GGML_OP_CONV_1D_2S:
                     {
@@ -10226,9 +11932,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
-        perf_total_per_op_us[node->op] += node->perf_time_us;
+        perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
 
-        GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
+        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
                 i,
                 node->ne[0], node->ne[1], node->ne[2],
                 GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -10242,13 +11948,17 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
     for (int i = 0; i < cgraph->n_leafs; i++) {
         struct ggml_tensor * node = cgraph->leafs[i];
 
-        GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
+        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
                 i,
                 node->ne[0], node->ne[1],
                 GGML_OP_LABEL[node->op]);
     }
 
     for (int i = 0; i < GGML_OP_COUNT; i++) {
+        if (perf_total_per_op_us[i] == 0) {
+            continue;
+        }
+
         GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
     }
 
@@ -11072,17 +12782,17 @@ enum ggml_opt_result ggml_opt(
 ////////////////////////////////////////////////////////////////////////////////
 
 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     for (int j = 0; j < n; j += k) {
-        block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
+        block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;
 
         quantize_row_q4_0_reference(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < QK; l += 2) {
-                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+            for (int l = 0; l < QK4_0; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
                 hist[vi0]++;
@@ -11091,21 +12801,44 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
         }
     }
 
-    return (n/QK*sizeof(block_q4_0));
+    return (n/QK4_0*sizeof(block_q4_0));
 }
 
 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     for (int j = 0; j < n; j += k) {
-        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
+        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;
 
         quantize_row_q4_1_reference(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < QK; l += 2) {
-                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+            for (int l = 0; l < QK4_1; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK4_1*sizeof(block_q4_1));
+}
+
+size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK4_2 == 0);
+    const int nb = k / QK4_2;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
+
+        quantize_row_q4_2_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK4_2; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
                 hist[vi0]++;
@@ -11114,7 +12847,133 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
         }
     }
 
-    return (n/QK*sizeof(block_q4_1));
+    return (n/QK4_2*sizeof(block_q4_2));
+}
+
+size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK5_0 == 0);
+    const int nb = k / QK5_0;
+
+    for (int j = 0; j < n; j += k) {
+        block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
+
+        quantize_row_q5_0_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            uint32_t qh;
+            memcpy(&qh, &y[i].qh, sizeof(qh));
+
+            for (int l = 0; l < QK5_0; l += 2) {
+                const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+                const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+                // cast to 16 bins
+                const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
+                const uint8_t vi1 = ((y[i].qs[l/2] >>   4) | vh1) / 2;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK5_0*sizeof(block_q5_0));
+}
+
+size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK5_1 == 0);
+    const int nb = k / QK5_1;
+
+    for (int j = 0; j < n; j += k) {
+        block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
+
+        quantize_row_q5_1_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            uint32_t qh;
+            memcpy(&qh, &y[i].qh, sizeof(qh));
+
+            for (int l = 0; l < QK5_1; l += 2) {
+                const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+                const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+                // cast to 16 bins
+                const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
+                const uint8_t vi1 = ((y[i].qs[l/2] >>   4) | vh1) / 2;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK5_1*sizeof(block_q5_1));
+}
+
+size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    for (int j = 0; j < n; j += k) {
+        block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0;
+
+        quantize_row_q8_0_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK8_0; ++l) {
+                const int8_t vi = y[i].qs[l];
+
+                hist[vi/16 + 8]++;
+            }
+        }
+    }
+
+    return (n/QK8_0*sizeof(block_q8_0));
+}
+
+size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
+    size_t result = 0;
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(start % QK4_0 == 0);
+                block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
+                result = ggml_quantize_q4_0(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(start % QK4_1 == 0);
+                block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
+                result = ggml_quantize_q4_1(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_2:
+            {
+                GGML_ASSERT(start % QK4_2 == 0);
+                block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
+                result = ggml_quantize_q4_2(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q5_0:
+            {
+                GGML_ASSERT(start % QK5_0 == 0);
+                block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
+                result = ggml_quantize_q5_0(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q5_1:
+            {
+                GGML_ASSERT(start % QK5_1 == 0);
+                block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
+                result = ggml_quantize_q5_1(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q8_0:
+            {
+                GGML_ASSERT(start % QK8_0 == 0);
+                block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
+                result = ggml_quantize_q8_0(src + start, block, n, n, hist);
+            } break;
+        default:
+            assert(false);
+    }
+    return result;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -11143,6 +13002,22 @@ int ggml_cpu_has_avx512(void) {
 #endif
 }
 
+int ggml_cpu_has_avx512_vbmi(void) {
+#if defined(__AVX512VBMI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vnni(void) {
+#if defined(__AVX512VNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
@@ -11192,13 +13067,33 @@ int ggml_cpu_has_wasm_simd(void) {
 }
 
 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_cublas(void) {
+#if defined(GGML_USE_CUBLAS)
     return 1;
 #else
     return 0;
 #endif
 }
 
+int ggml_cpu_has_clblast(void) {
+#if defined(GGML_USE_CLBLAST)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_gpublas(void) {
+    return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
+}
+
 int ggml_cpu_has_sse3(void) {
 #if defined(__SSE3__)
     return 1;
diff --git a/ggml.h b/ggml.h
index bdff0b4de454e1caee160c2e8391479f91c3217e..38ae9a6eeeb71b29cb40a8e318ef3319334c3b0c 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 //
 //
 
-#ifdef  __cplusplus
-extern "C" {
+#ifdef GGML_SHARED
+#    if defined(_WIN32) && !defined(__MINGW32__)
+#        ifdef GGML_BUILD
+#            define GGML_API __declspec(dllexport)
+#        else
+#            define GGML_API __declspec(dllimport)
+#        endif
+#    else
+#        define GGML_API __attribute__ ((visibility ("default")))
+#    endif
+#else
+#    define GGML_API
 #endif
 
 #include <stdint.h>
 #include <stddef.h>
 #include <stdbool.h>
 
+#define GGML_FILE_MAGIC   0x67676d6c // "ggml"
+#define GGML_FILE_VERSION 1
+
 #define GGML_MAX_DIMS          4
 #define GGML_MAX_NODES         4096
 #define GGML_MAX_PARAMS        16
@@ -184,660 +197,704 @@ extern "C" {
 #define GGML_MAX_OPT           4
 #define GGML_DEFAULT_N_THREADS 4
 
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
 #ifdef __ARM_NEON
-// we use the built-in 16-bit float type
-typedef __fp16 ggml_fp16_t;
+    // we use the built-in 16-bit float type
+    typedef __fp16 ggml_fp16_t;
 #else
-typedef uint16_t ggml_fp16_t;
+    typedef uint16_t ggml_fp16_t;
 #endif
 
-// convert FP16 <-> FP32
-float       ggml_fp16_to_fp32(ggml_fp16_t x);
-ggml_fp16_t ggml_fp32_to_fp16(float x);
-
-struct ggml_object;
-struct ggml_context;
-
-enum ggml_type {
-    // explicitly numbered values are used in llama.cpp files
-    GGML_TYPE_F32  = 0,
-    GGML_TYPE_F16  = 1,
-    GGML_TYPE_Q4_0 = 2,
-    GGML_TYPE_Q4_1 = 3,
-    GGML_TYPE_I8,
-    GGML_TYPE_I16,
-    GGML_TYPE_I32,
-    GGML_TYPE_COUNT,
-};
-
-// available tensor operations:
-enum ggml_op {
-    GGML_OP_NONE = 0,
-
-    GGML_OP_DUP,
-    GGML_OP_ADD,
-    GGML_OP_SUB,
-    GGML_OP_MUL,
-    GGML_OP_DIV,
-    GGML_OP_SQR,
-    GGML_OP_SQRT,
-    GGML_OP_SUM,
-    GGML_OP_MEAN,
-    GGML_OP_REPEAT,
-    GGML_OP_ABS,
-    GGML_OP_SGN,
-    GGML_OP_NEG,
-    GGML_OP_STEP,
-    GGML_OP_RELU,
-    GGML_OP_GELU,
-    GGML_OP_SILU,
-    GGML_OP_NORM, // normalize
-    GGML_OP_RMS_NORM,
-
-    GGML_OP_MUL_MAT,
-
-    GGML_OP_SCALE,
-    GGML_OP_CPY,
-    GGML_OP_CONT,
-    GGML_OP_RESHAPE,
-    GGML_OP_VIEW,
-    GGML_OP_PERMUTE,
-    GGML_OP_TRANSPOSE,
-    GGML_OP_GET_ROWS,
-    GGML_OP_DIAG_MASK_INF,
-    GGML_OP_SOFT_MAX,
-    GGML_OP_ROPE,
-    GGML_OP_CONV_1D_1S,
-    GGML_OP_CONV_1D_2S,
-
-    GGML_OP_FLASH_ATTN,
-    GGML_OP_FLASH_FF,
-
-    GGML_OP_MAP_UNARY,
-    GGML_OP_MAP_BINARY,
-
-    GGML_OP_COUNT,
-};
-
-
-// ggml object
-struct ggml_object {
-    size_t offs;
-    size_t size;
-
-    struct ggml_object * next;
-
-    char padding[8];
-};
-
-static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
-
-// n-dimensional tensor
-struct ggml_tensor {
-    enum ggml_type type;
-
-    int    n_dims;
-    int64_t ne[GGML_MAX_DIMS]; // number of elements
-    size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
-                               // nb[0] = sizeof(type)
-                               // nb[1] = nb[0]   * ne[0] + padding
-                               // nb[i] = nb[i-1] * ne[i-1]
-
-    // compute data
-    enum ggml_op op;
-
-    bool is_param;
-
-    struct ggml_tensor * grad;
-    struct ggml_tensor * src0;
-    struct ggml_tensor * src1;
-    struct ggml_tensor * opt[GGML_MAX_OPT];
-
-    // thread scheduling
-    int n_tasks;
-
-    // performance
-    int     perf_runs;
-    int64_t perf_cycles;
-    int64_t perf_time_us;
-
-    void * data;
-    char padding[8];
-};
-
-// computation graph
-struct ggml_cgraph {
-    int n_nodes;
-    int n_leafs;
-    int n_threads;
-
-    size_t work_size;
-    struct ggml_tensor * work;
-
-    struct ggml_tensor * nodes[GGML_MAX_NODES];
-    struct ggml_tensor * grads[GGML_MAX_NODES];
-    struct ggml_tensor * leafs[GGML_MAX_NODES];
-
-    // performance
-    int     perf_runs;
-    int64_t perf_cycles;
-    int64_t perf_time_us;
-};
-
-// scratch buffer
-struct ggml_scratch {
-    size_t offs;
-    size_t size;
-    void * data;
-};
-
-struct ggml_init_params {
-    // memory pool
-    size_t mem_size;   // bytes
-    void * mem_buffer; // if NULL, memory will be allocated internally
-    bool   no_alloc;   // don't allocate memory for the tensor data
-};
-
-void    ggml_time_init(void); // call this once at the beginning of the program
-int64_t ggml_time_ms(void);
-int64_t ggml_time_us(void);
-int64_t ggml_cycles(void);
-int64_t ggml_cycles_per_ms(void);
+    // convert FP16 <-> FP32
+    GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t x);
+    GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
+
+    struct ggml_object;
+    struct ggml_context;
+
+    enum ggml_type {
+        GGML_TYPE_F32  = 0,
+        GGML_TYPE_F16  = 1,
+        GGML_TYPE_Q4_0 = 2,
+        GGML_TYPE_Q4_1 = 3,
+        GGML_TYPE_Q4_2 = 4,
+        // GGML_TYPE_Q4_3 (5) support has been removed
+        GGML_TYPE_Q5_0 = 6,
+        GGML_TYPE_Q5_1 = 7,
+        GGML_TYPE_Q8_0 = 8,
+        GGML_TYPE_Q8_1 = 9,
+        GGML_TYPE_I8,
+        GGML_TYPE_I16,
+        GGML_TYPE_I32,
+        GGML_TYPE_COUNT,
+    };
+
+    // available tensor operations:
+    enum ggml_op {
+        GGML_OP_NONE = 0,
+
+        GGML_OP_DUP,
+        GGML_OP_ADD,
+        GGML_OP_SUB,
+        GGML_OP_MUL,
+        GGML_OP_DIV,
+        GGML_OP_SQR,
+        GGML_OP_SQRT,
+        GGML_OP_SUM,
+        GGML_OP_MEAN,
+        GGML_OP_REPEAT,
+        GGML_OP_ABS,
+        GGML_OP_SGN,
+        GGML_OP_NEG,
+        GGML_OP_STEP,
+        GGML_OP_RELU,
+        GGML_OP_GELU,
+        GGML_OP_SILU,
+        GGML_OP_NORM, // normalize
+        GGML_OP_RMS_NORM,
+
+        GGML_OP_MUL_MAT,
+
+        GGML_OP_SCALE,
+        GGML_OP_CPY,
+        GGML_OP_CONT,
+        GGML_OP_RESHAPE,
+        GGML_OP_VIEW,
+        GGML_OP_PERMUTE,
+        GGML_OP_TRANSPOSE,
+        GGML_OP_GET_ROWS,
+        GGML_OP_DIAG_MASK_INF,
+        GGML_OP_SOFT_MAX,
+        GGML_OP_ROPE,
+        GGML_OP_ALIBI,
+        GGML_OP_CONV_1D_1S,
+        GGML_OP_CONV_1D_2S,
+
+        GGML_OP_FLASH_ATTN,
+        GGML_OP_FLASH_FF,
+
+        GGML_OP_MAP_UNARY,
+        GGML_OP_MAP_BINARY,
+
+        GGML_OP_COUNT,
+    };
+
+
+    // ggml object
+    struct ggml_object {
+        size_t offs;
+        size_t size;
+
+        struct ggml_object * next;
+
+        char padding[8];
+    };
+
+    static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+    // n-dimensional tensor
+    struct ggml_tensor {
+        enum ggml_type type;
+
+        int     n_dims;
+        int64_t ne[GGML_MAX_DIMS]; // number of elements
+        size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
+                                   // nb[0] = sizeof(type)
+                                   // nb[1] = nb[0]   * ne[0] + padding
+                                   // nb[i] = nb[i-1] * ne[i-1]
+
+        // compute data
+        enum ggml_op op;
+
+        bool is_param;
+
+        struct ggml_tensor * grad;
+        struct ggml_tensor * src0;
+        struct ggml_tensor * src1;
+        struct ggml_tensor * opt[GGML_MAX_OPT];
+
+        // thread scheduling
+        int n_tasks;
+
+        // performance
+        int     perf_runs;
+        int64_t perf_cycles;
+        int64_t perf_time_us;
+
+        void * data;
+        char padding[8];
+    };
+
+    // computation graph
+    struct ggml_cgraph {
+        int n_nodes;
+        int n_leafs;
+        int n_threads;
+
+        size_t work_size;
+        struct ggml_tensor * work;
+
+        struct ggml_tensor * nodes[GGML_MAX_NODES];
+        struct ggml_tensor * grads[GGML_MAX_NODES];
+        struct ggml_tensor * leafs[GGML_MAX_NODES];
+
+        // performance
+        int     perf_runs;
+        int64_t perf_cycles;
+        int64_t perf_time_us;
+    };
+
+    // scratch buffer
+    struct ggml_scratch {
+        size_t offs;
+        size_t size;
+        void * data;
+    };
 
-void ggml_print_object (const struct ggml_object * obj);
-void ggml_print_objects(const struct ggml_context * ctx);
+    struct ggml_init_params {
+        // memory pool
+        size_t mem_size;   // bytes
+        void * mem_buffer; // if NULL, memory will be allocated internally
+        bool   no_alloc;   // don't allocate memory for the tensor data
+    };
 
-int64_t ggml_nelements(const struct ggml_tensor * tensor);
-size_t  ggml_nbytes   (const struct ggml_tensor * tensor);
+    // misc
 
-int    ggml_blck_size (enum ggml_type type);
-size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
-float  ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
+    GGML_API void    ggml_time_init(void); // call this once at the beginning of the program
+    GGML_API int64_t ggml_time_ms(void);
+    GGML_API int64_t ggml_time_us(void);
+    GGML_API int64_t ggml_cycles(void);
+    GGML_API int64_t ggml_cycles_per_ms(void);
 
-size_t ggml_element_size(const struct ggml_tensor * tensor);
+    GGML_API void    ggml_print_object (const struct ggml_object * obj);
+    GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
 
-struct ggml_context * ggml_init(struct ggml_init_params params);
-void ggml_free(struct ggml_context * ctx);
+    GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
+    GGML_API size_t  ggml_nbytes   (const struct ggml_tensor * tensor);
 
-size_t ggml_used_mem(const struct ggml_context * ctx);
-
-size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
-
-struct ggml_tensor * ggml_new_tensor(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int    n_dims,
-        const int64_t *ne);
-
-struct ggml_tensor * ggml_new_tensor_1d(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int64_t ne0);
-
-struct ggml_tensor * ggml_new_tensor_2d(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int64_t ne0,
-        int64_t ne1);
-
-struct ggml_tensor * ggml_new_tensor_3d(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int64_t ne0,
-        int64_t ne1,
-        int64_t ne2);
-
-struct ggml_tensor * ggml_new_tensor_4d(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int64_t ne0,
-        int64_t ne1,
-        int64_t ne2,
-        int64_t ne3);
-
-struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
-struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
-
-struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
-struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
-
-struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
-struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
-struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
-
-int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
-void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
-
-float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
-void  ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
-
- void * ggml_get_data    (const struct ggml_tensor * tensor);
-float * ggml_get_data_f32(const struct ggml_tensor * tensor);
-
-//
-// operations on tensors with backpropagation
-//
-
-struct ggml_tensor * ggml_dup(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_add(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-struct ggml_tensor * ggml_sub(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-struct ggml_tensor * ggml_mul(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-struct ggml_tensor * ggml_div(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-struct ggml_tensor * ggml_sqr(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_sqrt(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// return scalar
-// TODO: compute sum along rows
-struct ggml_tensor * ggml_sum(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// mean along rows
-struct ggml_tensor * ggml_mean(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// if a is the same shape as b, and a is not parameter, return a
-// otherwise, return a new tensor: repeat(a) to fit in b
-struct ggml_tensor * ggml_repeat(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-struct ggml_tensor * ggml_abs(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_sgn(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_neg(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_step(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_relu(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// TODO: double-check this computation is correct
-struct ggml_tensor * ggml_gelu(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_silu(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// normalize along rows
-// TODO: eps is hardcoded to 1e-5 for now
-struct ggml_tensor * ggml_norm(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_rms_norm(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// A: m rows, n columns
-// B: p rows, n columns (i.e. we transpose it internally)
-// result is m columns, p rows
-struct ggml_tensor * ggml_mul_mat(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-//
-// operations on tensors without backpropagation
-//
-
-// in-place, returns view(a)
-struct ggml_tensor * ggml_scale(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-// a -> b, return view(b)
-struct ggml_tensor * ggml_cpy(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-// make contiguous
-struct ggml_tensor * ggml_cont(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// return view(a), b specifies the new shape
-// TODO: when we start computing gradient, make a copy instead of view
-struct ggml_tensor * ggml_reshape(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-// return view(a)
-// TODO: when we start computing gradient, make a copy instead of view
-struct ggml_tensor * ggml_reshape_2d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int64_t               ne0,
-        int64_t               ne1);
-
-// return view(a)
-// TODO: when we start computing gradient, make a copy instead of view
-struct ggml_tensor * ggml_reshape_3d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int64_t               ne0,
-        int64_t               ne1,
-        int64_t               ne2);
-
-// offset in bytes
-struct ggml_tensor * ggml_view_1d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int64_t               ne0,
-        size_t                offset);
-
-struct ggml_tensor * ggml_view_2d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int64_t               ne0,
-        int64_t               ne1,
-        size_t                nb1, // row stride in bytes
-        size_t                offset);
-
-struct ggml_tensor * ggml_view_3d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int64_t               ne0,
-        int64_t               ne1,
-        int64_t               ne2,
-        size_t                nb1, // row   stride in bytes
-        size_t                nb2, // slice stride in bytes
-        size_t                offset);
-
-struct ggml_tensor * ggml_permute(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int                   axis0,
-        int                   axis1,
-        int                   axis2,
-        int                   axis3);
-
-// alias for ggml_permute(ctx, a, 1, 0, 2, 3)
-struct ggml_tensor * ggml_transpose(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_get_rows(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-// set elements above the diagonal to -INF
-// in-place, returns view(a)
-struct ggml_tensor * ggml_diag_mask_inf(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int                   n_past);
-
-// in-place, returns view(a)
-struct ggml_tensor * ggml_soft_max(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// rotary position embedding
-// in-place, returns view(a)
-// if mode == 1, skip n_past elements
-// TODO: avoid creating a new tensor every time
-struct ggml_tensor * ggml_rope(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int                   n_past,
-        int                   n_dims,
-        int                   mode);
-
-// padding = 1
-// TODO: we don't support extra parameters for now
-//       that's why we are hard-coding the stride, padding, and dilation
-//       not great ..
-struct ggml_tensor * ggml_conv_1d_1s(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-struct ggml_tensor * ggml_conv_1d_2s(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-struct ggml_tensor * ggml_flash_attn(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * q,
-        struct ggml_tensor  * k,
-        struct ggml_tensor  * v,
-        bool                  masked);
-
-struct ggml_tensor * ggml_flash_ff(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b0,
-        struct ggml_tensor  * b1,
-        struct ggml_tensor  * c0,
-        struct ggml_tensor  * c1);
-
-// Mapping operations
-typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
-typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
-
-struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t fun);
-
-struct ggml_tensor * ggml_map_binary_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t fun);
-
-//
-// automatic differentiation
-//
-
-void ggml_set_param(
-        struct ggml_context * ctx,
-        struct ggml_tensor * tensor);
-
-void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-
-struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
-struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
-
-void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
-void ggml_graph_reset  (struct ggml_cgraph * cgraph);
-
-// print info and performance information for the graph
-void ggml_graph_print(const struct ggml_cgraph * cgraph);
-
-// dump the graph into a file using the dot format
-void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
-
-//
-// optimization
-//
-
-// optimization methods
-enum ggml_opt_type {
-    GGML_OPT_ADAM,
-    GGML_OPT_LBFGS,
-};
-
-// linesearch methods
-enum ggml_linesearch {
-    GGML_LINESEARCH_DEFAULT = 1,
-
-    GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
-    GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
-    GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
-};
-
-// optimization return values
-enum ggml_opt_result {
-    GGML_OPT_OK = 0,
-    GGML_OPT_DID_NOT_CONVERGE,
-    GGML_OPT_NO_CONTEXT,
-    GGML_OPT_INVALID_WOLFE,
-    GGML_OPT_FAIL,
+    GGML_API int     ggml_blck_size (enum ggml_type type);
+    GGML_API size_t  ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
+    GGML_API float   ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
 
-    GGML_LINESEARCH_FAIL = -128,
-    GGML_LINESEARCH_MINIMUM_STEP,
-    GGML_LINESEARCH_MAXIMUM_STEP,
-    GGML_LINESEARCH_MAXIMUM_ITERATIONS,
-    GGML_LINESEARCH_INVALID_PARAMETERS,
-};
+    GGML_API const char * ggml_type_name(enum ggml_type type);
 
-// optimization parameters
-//
-//   see ggml.c (ggml_opt_default_params) for default values
-//
-struct ggml_opt_params {
-    enum ggml_opt_type type;
+    GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
+
+    GGML_API bool    ggml_is_quantized(enum ggml_type type);
+
+    // main
+
+    GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
+    GGML_API void    ggml_free(struct ggml_context * ctx);
+
+    GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
+
+    GGML_API size_t  ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int    n_dims,
+            const int64_t *ne);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_1d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_2d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_3d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1,
+            int64_t ne2);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_4d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1,
+            int64_t ne2,
+            int64_t ne3);
+
+    GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
+    GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
+
+    GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
+    GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
+
+    GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
+    GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
 
-    int n_threads;
+    GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
+    GGML_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+
+    GGML_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
+    GGML_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+
+    GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
+    GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
-    // delta-based convergence test
     //
-    //   if past == 0 - disabled
-    //   if past > 0:
-    //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
+    // operations on tensors with backpropagation
     //
-    int past;
-    float delta;
 
-    // maximum number of iterations without improvement
+    GGML_API struct ggml_tensor * ggml_dup(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_add(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_add_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_sub(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_mul(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_div(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_sqr(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sqrt(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // return scalar
+    // TODO: compute sum along rows
+    GGML_API struct ggml_tensor * ggml_sum(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // mean along rows
+    GGML_API struct ggml_tensor * ggml_mean(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // if a is the same shape as b, and a is not parameter, return a
+    // otherwise, return a new tensor: repeat(a) to fit in b
+    GGML_API struct ggml_tensor * ggml_repeat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_abs(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sgn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_neg(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_step(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_relu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // TODO: double-check this computation is correct
+    GGML_API struct ggml_tensor * ggml_gelu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_silu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // normalize along rows
+    // TODO: eps is hardcoded to 1e-5 for now
+    GGML_API struct ggml_tensor * ggml_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_rms_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // A: m rows, n columns
+    // B: p rows, n columns (i.e. we transpose it internally)
+    // result is m columns, p rows
+    GGML_API struct ggml_tensor * ggml_mul_mat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     //
-    //   if 0 - disabled
-    //   if > 0:
-    //     assume convergence if no cost improvement in this number of iterations
+    // operations on tensors without backpropagation
     //
-    int max_no_improvement;
 
-    bool print_forward_graph;
-    bool print_backward_graph;
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_scale(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // a -> b, return view(b)
+    GGML_API struct ggml_tensor * ggml_cpy(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // make contiguous
+    GGML_API struct ggml_tensor * ggml_cont(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // return view(a), b specifies the new shape
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // return view(a)
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1);
+
+    // return view(a)
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2);
+
+    // offset in bytes
+    GGML_API struct ggml_tensor * ggml_view_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_view_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            size_t                nb1, // row stride in bytes
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_view_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            size_t                nb1, // row   stride in bytes
+            size_t                nb2, // slice stride in bytes
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_permute(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   axis0,
+            int                   axis1,
+            int                   axis2,
+            int                   axis3);
+
+    // alias for ggml_permute(ctx, a, 1, 0, 2, 3)
+    GGML_API struct ggml_tensor * ggml_transpose(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_get_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // set elements above the diagonal to -INF
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_diag_mask_inf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // rotary position embedding
+    // in-place, returns view(a)
+    // if mode & 1 == 1, skip n_past elements
+    // if mode & 2 == 1, GPT-NeoX style
+    // TODO: avoid creating a new tensor every time
+    GGML_API struct ggml_tensor * ggml_rope(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode);
+
+    // alibi position embedding
+    // in-place, returns view(a)
+    struct ggml_tensor * ggml_alibi(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_head);
+
+    // padding = 1
+    // TODO: we don't support extra parameters for now
+    //       that's why we are hard-coding the stride, padding, and dilation
+    //       not great ..
+    GGML_API struct ggml_tensor * ggml_conv_1d_1s(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_conv_1d_2s(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_flash_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            bool                  masked);
+
+    GGML_API struct ggml_tensor * ggml_flash_ff(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b0,
+            struct ggml_tensor  * b1,
+            struct ggml_tensor  * c0,
+            struct ggml_tensor  * c1);
+
+    // Mapping operations
+    GGML_API typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
+    GGML_API typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
+
+    GGML_API struct ggml_tensor * ggml_map_unary_f32(
+            struct ggml_context        * ctx,
+            struct ggml_tensor         * a,
+            const  ggml_unary_op_f32_t fun);
+
+    GGML_API struct ggml_tensor * ggml_map_binary_f32(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b,
+            const  ggml_binary_op_f32_t fun);
 
-    // ADAM parameters
-    struct {
-        int n_iter;
+    //
+    // automatic differentiation
+    //
 
-        float alpha; // learning rate
-        float beta1;
-        float beta2;
-        float eps;   // epsilon for numerical stability
-        float eps_f; // epsilon for convergence test
-        float eps_g; // epsilon for convergence test
-    } adam;
+    GGML_API void ggml_set_param(
+            struct ggml_context * ctx,
+            struct ggml_tensor * tensor);
 
-    // LBFGS parameters
-    struct {
-        int m; // number of corrections to approximate the inv. Hessian
-        int n_iter;
-        int max_linesearch;
+    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
-        float eps;      // convergence tolerance
-        float ftol;     // line search tolerance
-        float wolfe;
-        float min_step;
-        float max_step;
+    GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
+    GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
 
-        enum ggml_linesearch linesearch;
-    } lbfgs;
-};
+    GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+    GGML_API void ggml_graph_reset  (struct ggml_cgraph * cgraph);
 
-struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
+    // print info and performance information for the graph
+    GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
 
-// optimize the function defined by the tensor f
-enum ggml_opt_result ggml_opt(
-        struct ggml_context * ctx,
-        struct ggml_opt_params params,
-        struct ggml_tensor * f);
+    // dump the graph into a file using the dot format
+    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
 
-//
-// quantization
-//
+    //
+    // optimization
+    //
 
-size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
+    // optimization methods
+    enum ggml_opt_type {
+        GGML_OPT_ADAM,
+        GGML_OPT_LBFGS,
+    };
+
+    // linesearch methods
+    enum ggml_linesearch {
+        GGML_LINESEARCH_DEFAULT = 1,
+
+        GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
+        GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
+        GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
+    };
+
+    // optimization return values
+    enum ggml_opt_result {
+        GGML_OPT_OK = 0,
+        GGML_OPT_DID_NOT_CONVERGE,
+        GGML_OPT_NO_CONTEXT,
+        GGML_OPT_INVALID_WOLFE,
+        GGML_OPT_FAIL,
+
+        GGML_LINESEARCH_FAIL = -128,
+        GGML_LINESEARCH_MINIMUM_STEP,
+        GGML_LINESEARCH_MAXIMUM_STEP,
+        GGML_LINESEARCH_MAXIMUM_ITERATIONS,
+        GGML_LINESEARCH_INVALID_PARAMETERS,
+    };
+
+    // optimization parameters
+    //
+    //   see ggml.c (ggml_opt_default_params) for default values
+    //
+    struct ggml_opt_params {
+        enum ggml_opt_type type;
+
+        int n_threads;
+
+        // delta-based convergence test
+        //
+        //   if past == 0 - disabled
+        //   if past > 0:
+        //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
+        //
+        int past;
+        float delta;
+
+        // maximum number of iterations without improvement
+        //
+        //   if 0 - disabled
+        //   if > 0:
+        //     assume convergence if no cost improvement in this number of iterations
+        //
+        int max_no_improvement;
+
+        bool print_forward_graph;
+        bool print_backward_graph;
+
+        // ADAM parameters
+        struct {
+            int n_iter;
+
+            float alpha; // learning rate
+            float beta1;
+            float beta2;
+            float eps;   // epsilon for numerical stability
+            float eps_f; // epsilon for convergence test
+            float eps_g; // epsilon for convergence test
+        } adam;
+
+        // LBFGS parameters
+        struct {
+            int m; // number of corrections to approximate the inv. Hessian
+            int n_iter;
+            int max_linesearch;
+
+            float eps;      // convergence tolerance
+            float ftol;     // line search tolerance
+            float wolfe;
+            float min_step;
+            float max_step;
+
+            enum ggml_linesearch linesearch;
+        } lbfgs;
+    };
+
+    GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
+
+    // optimize the function defined by the tensor f
+    GGML_API enum ggml_opt_result ggml_opt(
+            struct ggml_context * ctx,
+            struct ggml_opt_params params,
+            struct ggml_tensor * f);
 
-//
-// system info
-//
+    //
+    // quantization
+    //
 
-int ggml_cpu_has_avx(void);
-int ggml_cpu_has_avx2(void);
-int ggml_cpu_has_avx512(void);
-int ggml_cpu_has_fma(void);
-int ggml_cpu_has_neon(void);
-int ggml_cpu_has_arm_fma(void);
-int ggml_cpu_has_f16c(void);
-int ggml_cpu_has_fp16_va(void);
-int ggml_cpu_has_wasm_simd(void);
-int ggml_cpu_has_blas(void);
-int ggml_cpu_has_sse3(void);
-int ggml_cpu_has_vsx(void);
+    GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
 
+    GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
 
-//
-// Internal types and functions exposed for tests and benchmarks
-//
+    //
+    // system info
+    //
+
+    GGML_API int ggml_cpu_has_avx        (void);
+    GGML_API int ggml_cpu_has_avx2       (void);
+    GGML_API int ggml_cpu_has_avx512     (void);
+    GGML_API int ggml_cpu_has_avx512_vbmi(void);
+    GGML_API int ggml_cpu_has_avx512_vnni(void);
+    GGML_API int ggml_cpu_has_fma        (void);
+    GGML_API int ggml_cpu_has_neon       (void);
+    GGML_API int ggml_cpu_has_arm_fma    (void);
+    GGML_API int ggml_cpu_has_f16c       (void);
+    GGML_API int ggml_cpu_has_fp16_va    (void);
+    GGML_API int ggml_cpu_has_wasm_simd  (void);
+    GGML_API int ggml_cpu_has_blas       (void);
+    GGML_API int ggml_cpu_has_cublas     (void);
+    GGML_API int ggml_cpu_has_clblast    (void);
+    GGML_API int ggml_cpu_has_gpublas    (void);
+    GGML_API int ggml_cpu_has_sse3       (void);
+    GGML_API int ggml_cpu_has_vsx        (void);
+
+    //
+    // Internal types and functions exposed for tests and benchmarks
+    //
 
 #ifdef  __cplusplus
-// restrict not standard in C++
+    // restrict not standard in C++
 #define GGML_RESTRICT
 #else
 #define GGML_RESTRICT restrict
 #endif
-typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-typedef void (*quantize_row_q_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-typedef void (*vec_dot_q_t)(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
-
-typedef struct {
-    dequantize_row_q_t dequantize_row_q;
-    quantize_row_q_t   quantize_row_q;
-    quantize_row_q_t   quantize_row_q_reference;
-    vec_dot_q_t        vec_dot_q;
-} quantize_fns_t;
-
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
+    typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+    typedef void (*quantize_row_q_t)  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
+    typedef void (*vec_dot_q_t)       (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
+
+    typedef struct {
+        dequantize_row_q_t dequantize_row_q;
+        quantize_row_q_t   quantize_row_q;
+        quantize_row_q_t   quantize_row_q_reference;
+        quantize_row_q_t   quantize_row_q_dot;
+        vec_dot_q_t        vec_dot_q;
+        enum ggml_type     vec_dot_type;
+    } quantize_fns_t;
+
+    quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
 
 #ifdef  __cplusplus
 }