]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp
authorGeorgi Gerganov <redacted>
Wed, 19 Apr 2023 17:20:23 +0000 (20:20 +0300)
committerGeorgi Gerganov <redacted>
Wed, 19 Apr 2023 17:20:35 +0000 (20:20 +0300)
include/ggml/ggml.h
src/ggml.c

index 241e96a1975b1c935fa96fabd2585b31f905e2b0..570147fc246587603042beed51a0cdb61bae17f1 100644 (file)
@@ -204,7 +204,8 @@ enum ggml_type {
     GGML_TYPE_F16  = 1,
     GGML_TYPE_Q4_0 = 2,
     GGML_TYPE_Q4_1 = 3,
-    GGML_TYPE_Q8_0 = 4,
+    GGML_TYPE_Q4_2 = 4,
+    GGML_TYPE_Q8_0 = 5,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
@@ -430,6 +431,12 @@ struct ggml_tensor * ggml_add(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b);
 
+
+struct ggml_tensor * ggml_add_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
 struct ggml_tensor * ggml_sub(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -800,6 +807,7 @@ enum ggml_opt_result ggml_opt(
 
 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
+size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist);
 
 //
 // system info
@@ -808,6 +816,8 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
 int ggml_cpu_has_avx(void);
 int ggml_cpu_has_avx2(void);
 int ggml_cpu_has_avx512(void);
+int ggml_cpu_has_avx512_vbmi(void);
+int ggml_cpu_has_avx512_vnni(void);
 int ggml_cpu_has_fma(void);
 int ggml_cpu_has_neon(void);
 int ggml_cpu_has_arm_fma(void);
@@ -815,6 +825,7 @@ int ggml_cpu_has_f16c(void);
 int ggml_cpu_has_fp16_va(void);
 int ggml_cpu_has_wasm_simd(void);
 int ggml_cpu_has_blas(void);
+int ggml_cpu_has_cublas(void);
 int ggml_cpu_has_sse3(void);
 int ggml_cpu_has_vsx(void);
 
index ccad76e8c018e31352d4c599277fccc41adaf1b3..3b38eaad3673612fe62cbcd11c4b9588d944062c 100644 (file)
@@ -142,10 +142,46 @@ inline static void* ggml_aligned_malloc(size_t size) {
         } \
     } while (0)
 
-#ifdef GGML_USE_ACCELERATE
+#if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
-#elif GGML_USE_OPENBLAS
+#elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
+#elif defined(GGML_USE_CUBLAS)
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#define CUDA_CHECK(err)                                                                            \
+    do {                                                                                           \
+        cudaError_t err_ = (err);                                                                  \
+        if (err_ != cudaSuccess) {                                                                 \
+            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,                       \
+                cudaGetErrorString(err_));                                                         \
+            exit(1);                                                                               \
+        }                                                                                          \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                                                          \
+    do {                                                                                           \
+        cublasStatus_t err_ = (err);                                                               \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                                       \
+            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);                        \
+            exit(1);                                                                               \
+        }                                                                                          \
+    } while (0)
+
+static cublasHandle_t cublasH = NULL;
+static cudaStream_t cudaStream = NULL;
+static void init_cublas(void) {
+    if (cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+    }
+}
 #endif
 
 #undef MIN
@@ -514,6 +550,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
         (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
 }
 
+inline static int16_t vaddvq_s8(int8x16_t v) {
+    return
+        (int16_t)vgetq_lane_s8(v, 0)  + (int16_t)vgetq_lane_s8(v, 1)  +
+        (int16_t)vgetq_lane_s8(v, 2)  + (int16_t)vgetq_lane_s8(v, 3)  +
+        (int16_t)vgetq_lane_s8(v, 4)  + (int16_t)vgetq_lane_s8(v, 5)  +
+        (int16_t)vgetq_lane_s8(v, 6)  + (int16_t)vgetq_lane_s8(v, 7)  +
+        (int16_t)vgetq_lane_s8(v, 8)  + (int16_t)vgetq_lane_s8(v, 9)  +
+        (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
+        (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
+        (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
+}
+
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -585,6 +633,13 @@ typedef struct {
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
 
+#define QK4_2 16
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qs[QK4_2 / 2]; // nibbles / quants
+} block_q4_2;
+static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
@@ -1045,6 +1100,49 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 #endif
 }
 
+// reference implementation for deterministic creation of model files
+static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
+    assert(k % QK4_2 == 0);
+
+    const int nb = k / QK4_2;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int l = 0; l < QK4_2; l++) {
+            const float v = x[i*QK4_2 + l];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 3) - 1);
+
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const float v0 = x[i*QK4_2 + l + 0]*id;
+            const float v1 = x[i*QK4_2 + l + 1]*id;
+
+            const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
+            const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+    }
+}
+
+static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_2 == 0);
+
+    block_q4_2 * restrict y = vy;
+
+    quantize_row_q4_2_reference(x, y, k);
+}
+
 // reference implementation for deterministic creation of model files
 static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
     assert(k % QK8_0 == 0);
@@ -1064,7 +1162,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
         y[i].d = d;
 
         for (int l = 0; l < QK8_0; ++l) {
-            const float   v  = x[i*QK8_0 + l]*id;
+            const float v = x[i*QK8_0 + l]*id;
             y[i].qs[l] = roundf(v);
         }
     }
@@ -1420,6 +1518,77 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 #endif
 }
 
+static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK4_2 == 0);
+    const int nb = k / QK4_2;
+
+    const block_q4_2 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = (vi0 - 8)*d;
+            const float v1 = (vi1 - 8)*d;
+
+            y[i*QK4_2 + l + 0] = v0;
+            y[i*QK4_2 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK4_2 + l + 0]));
+            assert(!isnan(y[i*QK4_2 + l + 1]));
+        }
+    }
+}
+
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q         = dequantize_row_q4_0,
+        .quantize_row_q           = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q         = dequantize_row_q4_1,
+        .quantize_row_q           = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_1_q8_0,
+    },
+    [GGML_TYPE_Q4_2] = {
+        .dequantize_row_q         = dequantize_row_q4_2,
+        .quantize_row_q           = quantize_row_q4_2,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_2_q8_0,
+    },
+    [GGML_TYPE_Q8_0] = {
+        .dequantize_row_q         = NULL,   // TODO
+        .quantize_row_q           = quantize_row_q8_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = NULL,   // TODO
+    },
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
+
 //
 // simd mappings
 //
@@ -1976,37 +2145,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
-#if __AVX512F__ && QK4_0 == 32
-static inline __m512 dot_q4_0_oneblock_avx512(
-    __m512 acc,
-    const block_q4_0 * restrict x,
-    const block_q4_0 * restrict y,
-    int i
-) {
-    // Compute combined scale for the block
-    __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
-
-    __m256i bx = bytesFromNibbles( x[i].qs );
-    __m256i by = bytesFromNibbles( y[i].qs );
-
-    // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-    const __m256i off = _mm256_set1_epi8( 8 );
-    bx = _mm256_sub_epi8( bx, off );
-    by = _mm256_sub_epi8( by, off );
-
-    // Sign-extend 16 signed bytes into int16_t
-    __m512i x32 = _mm512_cvtepi8_epi16( bx );
-    __m512i y32 = _mm512_cvtepi8_epi16( by );
-    // Compute products of int16_t integers, add pairwise
-    __m512i i64 = _mm512_madd_epi16( x32, y32 );
-
-    // Convert int32_t to float
-    __m512 p = _mm512_cvtepi32_ps( i64 );
-    // Apply the scale, and accumulate
-    return _mm512_fmadd_ps( d, p, acc );
-}
-#endif
-
 inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;
 
@@ -2043,67 +2181,64 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     *s = sumf;
 }
 
-static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK4_0;
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
 
-    assert(n % QK4_0 == 0);
+    assert(n % QK8_0 == 0);
     assert(nb % 2 == 0);
 
     const block_q4_0 * restrict x = vx;
-    const block_q4_0 * restrict y = vy;
+    const block_q8_0 * restrict y = vy;
 
     float sumf = 0.0;
 
 #if defined(__ARM_NEON)
-    float sum0 = 0.0f;
-    float sum1 = 0.0f;
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict y0 = &y[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q4_0 * restrict y1 = &y[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
 
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-        const int8x16_t  s8b = vdupq_n_s8(0x8);
+        const uint8x16_t m4b   = vdupq_n_u8(0xf);
+        const int8x16_t  s8b   = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
 
         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
-        const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
-
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
-        const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-        const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
 
         // sub 8
         const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
         const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
-
         const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
-        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
 
-        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
-        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
-
-        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -2115,115 +2250,51 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
         const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
 
-        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
-        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
-
-        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
-        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
-
-        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
-        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
 #endif
     }
 
-    sumf = sum0 + sum1;
-#elif defined(__AVX512F__)
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
     // Initialize accumulator with zeros
-    __m512 acc0 = _mm512_setzero_ps();
-    __m512 acc1 = _mm512_setzero_ps();
+    __m256 acc = _mm256_setzero_ps();
 
-    const int superblock_size = 8;
-    const int superblock_count = nb / superblock_size;
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
 
-    for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
-        int i = superblock_ix * superblock_size;
+        __m256i bx = bytesFromNibbles(x[i].qs);
 
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
-    }
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
 
-    // Remainders
-    for (int i = superblock_count * superblock_size; i < nb; ++i) {
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
-    }
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-    // Horizontal sum of all lanes of the accumulator
-    sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
-#elif defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
 
-    /* Prepare the constants we will need during execution */
-    const __m256i lowMask = _mm256_set1_epi8( 0xF );
-    const __m256i offset_8 = _mm256_set1_epi16( 8 );
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
 
-#define UNROLL_COUNT 8
-    // make sure we only unroll multiples of the block count
-    assert(nb % UNROLL_COUNT == 0);
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
 
-    // Main loop
-    for (int i = 0; i < nb; i+=UNROLL_COUNT) {
-        // This loop will be unrolled by the compiler
-        for (int u=0;u<UNROLL_COUNT;u++)  {
-            /* Compute combined scale for the block */
-            const __m256 scale = _mm256_mul_ps(
-                    _mm256_broadcast_ss( &x[i+u].d ),
-                    _mm256_broadcast_ss( &y[i+u].d ) );
-
-            /* get input from x
-               Input: 32 Nibbles (16 bytes) at *x[i+u]
-               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
-
-            /* Load 16 bytes from memory */
-            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
-            /* Unpack values into individual bytes */
-            __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
-            const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
-            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
-            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval.  Offset them into [ -8 .. +7 ] interval.  */
-            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
-            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
-
-            /* get input from y
-               Input: 32 Nibbles (16 bytes) at *y[i+u]
-               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
-
-            /* Load 16 bytes from memory */
-            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
-            /* Unpack values into individual bytes */
-            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
-            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
-            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
-            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval.  Offset them into [ -8 .. +7 ] interval.  */
-            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
-            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
-
-            /* Compute products of int16_t integers, add pairwise, store as int32_t */
-            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
-            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
-
-            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
-            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
-
-            /* Convert to vectore of 8 int32_t to 8 floats */
-            __m256 q = _mm256_cvtepi32_ps( xy_q );
-
-            /* Multiply q with scale and accumulate */
-            acc = _mm256_fmadd_ps( scale, q, acc );
-        }
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert to vectore of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
     }
 
     // Return horizontal sum of the acc vector
@@ -2246,12 +2317,11 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         for (int j = 0; j < 2; ++j) {
             // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
             __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
-            __m128i by = bytesFromNibbles( y[i].qs + 8*j );
+            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
 
             // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
             const __m128i off = _mm_set1_epi8( 8 );
             bx = _mm_sub_epi8( bx, off );
-            by = _mm_sub_epi8( by, off );
 
             // Get absolute values of x vectors
             const __m128i ax = _mm_sign_epi8(bx, bx);
@@ -2279,86 +2349,6 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
     sumf = _mm_cvtss_f32( res );
-#elif defined(__wasm_simd128__)
-    // wasm simd
-    float sum0 = 0.0f;
-    float sum1 = 0.0f;
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict y0 = &y[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q4_0 * restrict y1 = &y[i + 1];
-
-        const v128_t m4b = wasm_u8x16_splat(0xf);
-        const v128_t s8b = wasm_i8x16_splat(0x8);
-
-        const v128_t v0_0 = wasm_v128_load(x0->qs);
-        const v128_t v0_1 = wasm_v128_load(y0->qs);
-        const v128_t v1_0 = wasm_v128_load(x1->qs);
-        const v128_t v1_1 = wasm_v128_load(y1->qs);
-
-        // 4-bit -> 8-bit
-        const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
-        const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
-
-        const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
-        const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
-
-        const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
-        const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
-
-        const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
-        const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
-
-        // sub 8
-        const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
-        const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
-
-        const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
-        const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
-
-        const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
-        const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
-
-        const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
-        const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
-
-        // dot product into int16x8_t
-        const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
-        const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
-
-        const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
-        const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
-
-        const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
-        const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
-
-        const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
-        const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
-
-        const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
-        const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
-
-        const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
-        const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
-
-        const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
-        const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
-
-        sum0 += x0->d * y0->d * (
-                wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
-                wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
-                wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
-                wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
-        sum1 += x1->d * y1->d * (
-                wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
-                wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
-                wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
-                wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
-    }
-
-    sumf = sum0 + sum1;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
@@ -2366,202 +2356,187 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const float d1 = y[i].d;
 
         const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
 
         int sumi = 0;
-        for (int j = 0; j < QK4_0/2; j++) {
+        for (int j = 0; j < QK8_0/2; j++) {
             const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
 
-            const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
-            const int8_t i1 = (int8_t) (v0 >> 4)  - 8;
+            const int i0 = (int8_t) (v0 & 0xf) - 8;
+            const int i1 = (int8_t) (v0 >> 4)  - 8;
 
-            const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
-            const int8_t i3 = (int8_t) (v1 >> 4)  - 8;
+            const int i2 = p1[2*j + 0];
+            const int i3 = p1[2*j + 1];
 
             sumi += i0*i2 + i1*i3;
         }
-        sumf += d0 * d1 * sumi;
+        sumf += d0*d1*sumi;
     }
 #endif
 
     *s = sumf;
 }
 
-static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK4_1;
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
 
     const block_q4_1 * restrict x = vx;
-    const block_q4_1 * restrict y = vy;
+    const block_q8_0 * restrict y = vy;
 
     float sumf = 0.0;
 
-#if defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-    // Accumulator for constant offsets
-    float acc_offset = 0.0f;
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        const float * d0 = &x[i].d;
-        const float * d1 = &y[i].d;
-
-        const float * m0 = &x[i].m;
-        const float * m1 = &y[i].m;
-
-        const __m256 d0v = _mm256_broadcast_ss( d0 );
-        const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
-        const __m256 m1v = _mm256_broadcast_ss( m1 );
-
-        // Compute combined scale for the block
-        const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
-
-        // Compute cross scales for the block
-        const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
-        const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
-        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
-
-        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        __m256i bx = bytesFromNibbles( x[i].qs );
-        __m256i by = bytesFromNibbles( y[i].qs );
-
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval.
-
-        // Sign-extend first 16 signed bytes into int16_t
-        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
-        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        // Compute products of int16_t integers, add pairwise
-        __m256i i32 = _mm256_madd_epi16( x16, y16 );
-
-        // Sign-extend last 16 signed bytes into int16_t vectors
-        __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
-        __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        // Accumulate products of int16_t integers
-        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
-
-        // compute sums of unsigned bytes in bx, by in blocks of 8.
-        // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
-        // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
-        // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
-        __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
-        __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
-        __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
-        __m256  sums  = _mm256_cvtepi32_ps( sumsi );
-
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( i32 );
-        // Apply the scale, and accumulate
-        // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
-        acc = _mm256_fmadd_ps( scale_01, p, acc );
-        acc = _mm256_fmadd_ps( cross_scales, sums, acc );
-        // acc_offset += m0*m1 (for each entry in the block)
-        acc_offset += (*m0)*(*m1);
-    }
-
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
-#elif defined(__ARM_NEON)
-    float sum00 = 0.0f;
-    float sum01 = 0.0f;
-    float sum10 = 0.0f;
-    float sum11 = 0.0f;
+    // TODO: add AVX / WASM SIMD / etc
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
-        const block_q4_1 * restrict y0 = &y[i + 0];
         const block_q4_1 * restrict x1 = &x[i + 1];
-        const block_q4_1 * restrict y1 = &y[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
 
         // 4-bit -> 8-bit
-        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
-        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
-        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
-        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
-        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
-        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
-        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+        const int16x8_t s0i = vaddq_s16(
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
 
-        sum00 += x0->m*y0->m;
-        sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
-        sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
+        const int16x8_t s1i = vaddq_s16(
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
 
-        sum00 += x1->m*y1->m;
-        sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
-        sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
-        uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
-
-        p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
-        p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
-        sum11 += x0->d*y0->d*vaddvq_u32(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u32(p_1);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
-        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
-        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+#endif
+    }
 
-        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
-        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
-        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
-        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
 
-        const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
-        const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float * d0 = &x[i].d;
+        const float * d1 = &y[i].d;
+        const float * m0 = &x[i].m;
 
-        const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
-        const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
+        const __m256 d0v = _mm256_broadcast_ss( d0 );
+        const __m256 d1v = _mm256_broadcast_ss( d1 );
+        const __m256 m0v = _mm256_broadcast_ss( m0 );
 
-        const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
-        const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
+        // Compute combined scales
+        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
-        sum11 += x0->d*y0->d*vaddvq_u16(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u16(p_1);
-#endif
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        const __m256i bx = bytesFromNibbles( x[i].qs );
+        const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8( bx, bx );
+
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8( by, bx );
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16( ax, sy );
+        const __m256i ones = _mm256_set1_epi16( 1 );
+        const __m256i xy_q = _mm256_madd_epi16( ones, dot );
+
+        // Convert to vector of 8 int32_t to 8 floats
+        const __m256 xy = _mm256_cvtepi32_ps( xy_q );
+
+        // Accumulate d0*d1*x*y
+        acc = _mm256_fmadd_ps( d0d1, xy, acc );
+
+        // Compute sum of y values
+        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
+        const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
+
+        // Accumulate d1*m0*y
+        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }
 
-    sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
         const float d0 = x[i].d;
-        const float d1 = y[i].d;
-
         const float m0 = x[i].m;
-        const float m1 = y[i].m;
+        const float d1 = y[i].d;
 
         const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
 
-        for (int j = 0; j < QK4_1/2; j++) {
+        // TODO: this is very slow ..
+        for (int j = 0; j < QK8_0/2; j++) {
             const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
 
             const float f0 = d0*(v0 & 0xf) + m0;
             const float f1 = d0*(v0 >> 4)  + m0;
 
-            const float f2 = d1*(v1 & 0xf) + m1;
-            const float f3 = d1*(v1 >> 4)  + m1;
+            const float f2 = d1*p1[2*j + 0];
+            const float f3 = d1*p1[2*j + 1];
 
             sumf += f0*f2 + f1*f3;
         }
@@ -2571,32 +2546,35 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     *s = sumf;
 }
 
-static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK8_0;
 
     assert(n % QK8_0 == 0);
     assert(nb % 2 == 0);
+    assert(QK8_0 == 2*QK4_2);
 
-    const block_q4_0 * restrict x = vx;
+    const block_q4_2 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
     float sumf = 0.0;
 
 #if defined(__ARM_NEON)
-    float sum0 = 0.0f;
-    float sum1 = 0.0f;
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
+        const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
+        const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
+        const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
         const uint8x16_t m4b   = vdupq_n_u8(0xf);
         const int8x16_t  s8b   = vdupq_n_s8(0x8);
 
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+        const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
+        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
 
         // 4-bit -> 8-bit
         const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
@@ -2610,164 +2588,88 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
 
-        // load y
-        const int8x16_t v1_0l = vld1q_s8(y0->qs);
-        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-        const int8x16_t v1_1l = vld1q_s8(y1->qs);
-        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-        // interleave
-        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
-        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int32x4_t
-        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
-        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
-
-        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
-        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
-
-        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
-
-        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
-        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
-
-        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
-        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
-
-        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
-        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
-
-        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
-#endif
-    }
-
-    sumf = sum0 + sum1;
-#elif defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        /* Compute combined scale for the block */
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
-
-        __m256i bx = bytesFromNibbles(x[i].qs);
-
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-        const __m256i off = _mm256_set1_epi8( 8 );
-        bx = _mm256_sub_epi8( bx, off );
-
-        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
-
-        // Get absolute values of x vectors
-        const __m256i ax = _mm256_sign_epi8(bx, bx);
-
-        // Sign the values of the y vectors
-        const __m256i sy = _mm256_sign_epi8(by, bx);
-
-        // Perform multiplication and create 16-bit values
-        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-
-        const __m256i ones = _mm256_set1_epi16(1);
-        __m256i xy_q = _mm256_madd_epi16(ones, dot);
-
-        /* Convert to vectore of 8 int32_t to 8 floats */
-        __m256 q = _mm256_cvtepi32_ps( xy_q );
-
-        /* Multiply q with scale and accumulate */
-        acc = _mm256_fmadd_ps( d, q, acc );
-    }
-
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
-#elif defined(__AVX__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        // Compute combined scale for the block
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
-
-        __m128i i32[2];
-        for (int j = 0; j < 2; ++j) {
-            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
-            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
-            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
-
-            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-            const __m128i off = _mm_set1_epi8( 8 );
-            bx = _mm_sub_epi8( bx, off );
-
-            // Get absolute values of x vectors
-            const __m128i ax = _mm_sign_epi8(bx, bx);
-
-            // Sign the values of the y vectors
-            const __m128i sy = _mm_sign_epi8(by, bx);
-
-            // Perform multiplication and create 16-bit values
-            const __m128i dot = _mm_maddubs_epi16(ax, sy);
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
 
-            const __m128i ones = _mm_set1_epi16(1);
-            i32[j] = _mm_madd_epi16(ones, dot);
-        }
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
-        // Apply the scale, and accumulate
-        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
-    }
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
 
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+        sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
+
+        sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
+#endif
+    }
 
-    sumf = _mm_cvtss_f32( res );
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float d1 = y[i].d;
+        const uint8_t * restrict x0 = x[2*i + 0].qs;
+        const uint8_t * restrict x1 = x[2*i + 1].qs;
+        const  int8_t * restrict y0 = y[i].qs;
 
-        const uint8_t * restrict p0 = x[i].qs;
-        const  int8_t * restrict p1 = y[i].qs;
+        const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
+        const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
 
-        int sumi = 0;
-        for (int j = 0; j < QK8_0/2; j++) {
-            const uint8_t v0 = p0[j];
+        int sumi_0 = 0;
+        int sumi_1 = 0;
 
-            const int i0 = (int8_t) (v0 & 0xf) - 8;
-            const int i1 = (int8_t) (v0 >> 4)  - 8;
+        for (int j = 0; j < QK8_0/4; j++) {
+            const uint8_t v0 = x0[j];
+            const uint8_t v1 = x1[j];
 
-            const int i2 = p1[2*j + 0];
-            const int i3 = p1[2*j + 1];
+            const int i0_0 = (int8_t) (v0 & 0xf) - 8;
+            const int i1_0 = (int8_t) (v0 >> 4)  - 8;
 
-            sumi += i0*i2 + i1*i3;
+            const int i0_1 = (int8_t) (v1 & 0xf) - 8;
+            const int i1_1 = (int8_t) (v1 >> 4)  - 8;
+
+            const int i2_0 = y0[2*j + 0];
+            const int i3_0 = y0[2*j + 1];
+
+            const int i2_1 = y0[2*(j + QK8_0/4) + 0];
+            const int i3_1 = y0[2*(j + QK8_0/4) + 1];
+
+            sumi_0 += i0_0*i2_0 + i1_0*i3_0;
+            sumi_1 += i0_1*i2_1 + i1_1*i3_1;
         }
-        sumf += d0*d1*sumi;
+
+        sumf += (d0 * y[i].d) * sumi_0;
+        sumf += (d1 * y[i].d) * sumi_1;
     }
 #endif
 
@@ -3020,24 +2922,26 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F16]  = 1,
     [GGML_TYPE_Q4_0] = QK4_0,
     [GGML_TYPE_Q4_1] = QK4_1,
+    [GGML_TYPE_Q4_2] = QK4_2,
     [GGML_TYPE_Q8_0] = QK8_0,
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 9, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = sizeof(float),
     [GGML_TYPE_F16]  = sizeof(ggml_fp16_t),
     [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
     [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+    [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
     [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
     [GGML_TYPE_I8]   = sizeof(int8_t),
     [GGML_TYPE_I16]  = sizeof(int16_t),
     [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3045,12 +2949,26 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F16]  = "f16",
     [GGML_TYPE_Q4_0] = "q4_0",
     [GGML_TYPE_Q4_1] = "q4_1",
+    [GGML_TYPE_Q4_2] = "q4_2",
     [GGML_TYPE_Q8_0] = "q8_0",
     [GGML_TYPE_I8]   = "i8",
     [GGML_TYPE_I16]  = "i16",
     [GGML_TYPE_I32]  = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_NAME is outdated");
+
+static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_F32]  = false,
+    [GGML_TYPE_F16]  = false,
+    [GGML_TYPE_Q4_0] = true,
+    [GGML_TYPE_Q4_1] = true,
+    [GGML_TYPE_Q4_2] = true,
+    [GGML_TYPE_Q8_0] = true,
+    [GGML_TYPE_I8]   = false,
+    [GGML_TYPE_I16]  = false,
+    [GGML_TYPE_I32]  = false,
+};
+static_assert(GGML_TYPE_COUNT == 9, "GGML_IS_QUANTIZED is outdated");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -3312,6 +3230,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
+static inline bool ggml_is_quantized(enum ggml_type type) {
+    return GGML_IS_QUANTIZED[type];
+}
+
 static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1];
 }
@@ -3422,6 +3344,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
+        // initialize cuBLAS
+        #if defined(GGML_USE_CUBLAS)
+        init_cublas();
+        #endif
+
         is_first_call = false;
     }
 
@@ -5352,7 +5279,6 @@ static void ggml_compute_forward_dup_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(params->ith == 0);
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5364,6 +5290,11 @@ static void ggml_compute_forward_dup_f16(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
@@ -5374,19 +5305,40 @@ static void ggml_compute_forward_dup_f16(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+        // parallelize by elements
+        const int ne = ggml_nelements(dst);
+        const int dr = (ne + nth - 1) / nth;
+        const int ie0 = dr * ith;
+        const int ie1 = MIN(ie0 + dr, ne);
+
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb00),
+            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+
         return;
     }
 
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
     if (src0->type == dst->type &&
-        src0->ne[0] == dst->ne[0] &&
-        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        ne00 == ne0 &&
+        nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
         // copy by rows
         const size_t rs = ne00*nb00;
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     memcpy(
                         ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                         ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5400,21 +5352,21 @@ static void ggml_compute_forward_dup_f16(
     // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
 
     if (ggml_is_contiguous(dst)) {
-        if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+        if (nb00 == sizeof(ggml_fp16_t)) {
             if (dst->type == GGML_TYPE_F16) {
                 size_t id = 0;
-                const size_t rs = ne00*nb00;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            char * dst_ptr = (char *) dst->data + id*rs;
-
-                            memcpy(dst_ptr, src0_ptr, rs);
-
-                            id++;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
                         }
+                        id += rs * (ne01 - ir1);
                     }
                 }
             } else if (dst->type == GGML_TYPE_F32) {
@@ -5423,14 +5375,39 @@ static void ggml_compute_forward_dup_f16(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                             for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_is_quantized(dst->type)) {
+                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                            }
+
+                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
                     }
                 }
             } else {
@@ -5445,7 +5422,8 @@ static void ggml_compute_forward_dup_f16(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             for (int i00 = 0; i00 < ne00; i00++) {
                                 const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5453,6 +5431,7 @@ static void ggml_compute_forward_dup_f16(
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
                     }
                 }
             } else if (dst->type == GGML_TYPE_F16) {
@@ -5461,7 +5440,8 @@ static void ggml_compute_forward_dup_f16(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             for (int i00 = 0; i00 < ne00; i00++) {
                                 const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5469,6 +5449,7 @@ static void ggml_compute_forward_dup_f16(
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
                     }
                 }
             } else {
@@ -5487,7 +5468,20 @@ static void ggml_compute_forward_dup_f16(
     if (dst->type == GGML_TYPE_F16) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                               char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
@@ -5508,25 +5502,51 @@ static void ggml_compute_forward_dup_f16(
                         }
                     }
                 }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
             }
         }
     } else if (dst->type == GGML_TYPE_F32) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                               char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
 
                         *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
 
-                        if (++i10 == ne00) {
+                        if (++i10 == ne0) {
                             i10 = 0;
-                            if (++i11 == ne01) {
+                            if (++i11 == ne1) {
                                 i11 = 0;
-                                if (++i12 == ne02) {
+                                if (++i12 == ne2) {
                                     i12 = 0;
-                                    if (++i13 == ne03) {
+                                    if (++i13 == ne3) {
                                         i13 = 0;
                                     }
                                 }
@@ -5534,6 +5554,19 @@ static void ggml_compute_forward_dup_f16(
                         }
                     }
                 }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
             }
         }
     } else {
@@ -5545,7 +5578,6 @@ static void ggml_compute_forward_dup_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(params->ith == 0);
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5557,6 +5589,11 @@ static void ggml_compute_forward_dup_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
@@ -5567,19 +5604,40 @@ static void ggml_compute_forward_dup_f32(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+        // parallelize by elements
+        const int ne = ggml_nelements(dst);
+        const int dr = (ne + nth - 1) / nth;
+        const int ie0 = dr * ith;
+        const int ie1 = MIN(ie0 + dr, ne);
+
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb00),
+            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+
         return;
     }
 
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
     if (src0->type == dst->type &&
-        src0->ne[0] == dst->ne[0] &&
-        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        ne00 == ne0 &&
+        nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
         // copy by rows
         const size_t rs = ne00*nb00;
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     memcpy(
                         ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                         ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5592,21 +5650,21 @@ static void ggml_compute_forward_dup_f32(
 
     if (ggml_is_contiguous(dst)) {
         // TODO: simplify
-        if (src0->nb[0] == sizeof(float)) {
+        if (nb00 == sizeof(float)) {
             if (dst->type == GGML_TYPE_F32) {
                 size_t id = 0;
-                const size_t rs = ne00*nb00;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            char * dst_ptr = (char *) dst->data + id*rs;
-
-                            memcpy(dst_ptr, src0_ptr, rs);
-
-                            id++;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
                         }
+                        id += rs * (ne01 - ir1);
                     }
                 }
             } else if (dst->type == GGML_TYPE_F16) {
@@ -5615,7 +5673,8 @@ static void ggml_compute_forward_dup_f32(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             for (int i00 = 0; i00 < ne00; i00++) {
                                 const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5623,6 +5682,25 @@ static void ggml_compute_forward_dup_f32(
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_is_quantized(dst->type)) {
+                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
                     }
                 }
             } else {
@@ -5637,7 +5715,8 @@ static void ggml_compute_forward_dup_f32(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             for (int i00 = 0; i00 < ne00; i00++) {
                                 const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5645,6 +5724,7 @@ static void ggml_compute_forward_dup_f32(
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
                     }
                 }
             } else if (dst->type == GGML_TYPE_F16) {
@@ -5653,7 +5733,8 @@ static void ggml_compute_forward_dup_f32(
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                             for (int i00 = 0; i00 < ne00; i00++) {
                                 const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5661,6 +5742,7 @@ static void ggml_compute_forward_dup_f32(
                                 id++;
                             }
                         }
+                        id += ne00 * (ne01 - ir1);
                     }
                 }
             } else {
@@ -5672,6 +5754,7 @@ static void ggml_compute_forward_dup_f32(
     }
 
     // dst counters
+
     int64_t i10 = 0;
     int64_t i11 = 0;
     int64_t i12 = 0;
@@ -5680,20 +5763,34 @@ static void ggml_compute_forward_dup_f32(
     if (dst->type == GGML_TYPE_F32) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    i11++;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                               char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
 
                         memcpy(dst_ptr, src0_ptr, sizeof(float));
 
-                        if (++i10 == dst->ne[0]) {
+                        if (++i10 == ne0) {
                             i10 = 0;
-                            if (++i11 == dst->ne[1]) {
+                            if (++i11 == ne1) {
                                 i11 = 0;
-                                if (++i12 == dst->ne[2]) {
+                                if (++i12 == ne2) {
                                     i12 = 0;
-                                    if (++i13 == dst->ne[3]) {
+                                    if (++i13 == ne3) {
                                         i13 = 0;
                                     }
                                 }
@@ -5701,25 +5798,51 @@ static void ggml_compute_forward_dup_f32(
                         }
                     }
                 }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
             }
         }
     } else if (dst->type == GGML_TYPE_F16) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                               char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
 
                         *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
 
-                        if (++i10 == dst->ne[0]) {
+                        if (++i10 == ne0) {
                             i10 = 0;
-                            if (++i11 == dst->ne[1]) {
+                            if (++i11 == ne1) {
                                 i11 = 0;
-                                if (++i12 == dst->ne[2]) {
+                                if (++i12 == ne2) {
                                     i12 = 0;
-                                    if (++i13 == dst->ne[3]) {
+                                    if (++i13 == ne3) {
                                         i13 = 0;
                                     }
                                 }
@@ -5727,36 +5850,155 @@ static void ggml_compute_forward_dup_f32(
                         }
                     }
                 }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ASSERT(false); // TODO: implement
+    }
+}
+
+static void ggml_compute_forward_dup(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_dup_f16(params, src0, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_dup_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_add
+
+static void ggml_compute_forward_add_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (nb10 == sizeof(float)) {
+        for (int j = ith; j < n; j += nth) {
+#ifdef GGML_USE_ACCELERATE
+            vDSP_vadd(
+                    (float *) ((char *) src0->data + j*nb01), 1,
+                    (float *) ((char *) src1->data + j*nb11), 1,
+                    (float *) ((char *) dst->data  + j*nb1),  1, nc);
+#else
+            ggml_vec_add_f32(nc,
+                    (float *) ((char *) dst->data  + j*nb1),
+                    (float *) ((char *) src0->data + j*nb01),
+                    (float *) ((char *) src1->data + j*nb11));
+#endif
+        }
+    } else {
+        // src1 is not contiguous
+        for (int j = ith; j < n; j += nth) {
+            float * dst_ptr  = (float *) ((char *) dst->data  + j*nb1);
+            float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+
+                dst_ptr[i] = src0_ptr[i] + *src1_ptr;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_add_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(float)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
             }
         }
-    } else {
-        GGML_ASSERT(false); // TODO: implement
     }
-}
-
-static void ggml_compute_forward_dup(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_dup_f16(params, src0, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_dup_f32(params, src0, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
     }
 }
 
-// ggml_compute_forward_add
-
-static void ggml_compute_forward_add_f32(
+static void ggml_compute_forward_add_f16_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5782,35 +6024,135 @@ static void ggml_compute_forward_add_f32(
     const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
 
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
 
-    if (nb10 == sizeof(float)) {
-        for (int j = ith; j < n; j += nth) {
-#ifdef GGML_USE_ACCELERATE
-            vDSP_vadd(
-                    (float *) ((char *) src0->data + j*nb01), 1,
-                    (float *) ((char *) src1->data + j*nb11), 1,
-                    (float *) ((char *) dst->data  + j*nb1),  1, nc);
-#else
-            ggml_vec_add_f32(nc,
-                    (float *) ((char *) dst->data  + j*nb1),
-                    (float *) ((char *) src0->data + j*nb01),
-                    (float *) ((char *) src1->data + j*nb11));
-#endif
-        }
-    } else {
-        // src1 is not contiguous
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
         for (int j = ith; j < n; j += nth) {
-            float * dst_ptr  = (float *) ((char *) dst->data  + j*nb1);
-            float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
             for (int i = 0; i < nc; i++) {
-                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-
-                dst_ptr[i] = src0_ptr[i] + *src1_ptr;
+                ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
             }
         }
     }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    //const int64_t ne0  = dst->ne[0];
+    //const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type));
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb0));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, wdata, src1_row);
+        // quantize row to dst
+        quantize_row_q(wdata, dst_row, ne00);
+    }
 }
 
 static void ggml_compute_forward_add(
@@ -5823,6 +6165,24 @@ static void ggml_compute_forward_add(
             {
                 ggml_compute_forward_add_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
         default:
             {
                 GGML_ASSERT(false);
@@ -6720,7 +7080,7 @@ static void ggml_compute_forward_rms_norm(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
@@ -6760,7 +7120,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -6817,7 +7177,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -6831,6 +7191,21 @@ static void ggml_compute_forward_mul_mat_f32(
             return;
         }
 
+#if defined(GGML_USE_CUBLAS)
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -6838,15 +7213,37 @@ static void ggml_compute_forward_mul_mat_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
-
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
@@ -6976,7 +7373,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -6992,10 +7389,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
+#if defined(GGML_USE_CUBLAS)
+        ggml_fp16_t * const wdata = params->wdata;
 
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#else
+        float * const wdata = params->wdata;
+#endif
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if defined(GGML_USE_CUBLAS)
+                // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
+                {
+                    size_t id = 0;
+                    for (int64_t i01 = 0; i01 < ne11; ++i01) {
+                        for (int64_t i00 = 0; i00 < ne10; ++i00) {
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
+                        }
+                    }
+                }
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7004,7 +7428,32 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         }
                     }
                 }
+#endif
+
+#if defined(GGML_USE_CUBLAS)
+                const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
+                const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
 
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, CUDA_R_16F, ne00,
+                                    d_Y, CUDA_R_16F, ne10,
+                            &beta,  d_D, CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
+                            CUBLAS_GEMM_DEFAULT));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
@@ -7016,9 +7465,15 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
         return;
@@ -7102,30 +7557,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q         = dequantize_row_q4_0,
-        .quantize_row_q           = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .quantize_row_q_dot       = quantize_row_q8_0,
-        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q         = dequantize_row_q4_1,
-        .quantize_row_q           = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .quantize_row_q_dot       = quantize_row_q4_1,
-        .vec_dot_q                = ggml_vec_dot_q4_1,
-    },
-    // TODO: GGML_TYPE_Q8_0
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -7194,7 +7625,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -7211,6 +7642,21 @@ static void ggml_compute_forward_mul_mat_q_f32(
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 
+#if defined(GGML_USE_CUBLAS)
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 {
@@ -7226,15 +7672,38 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
@@ -7322,6 +7791,7 @@ static void ggml_compute_forward_mul_mat(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
         case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
@@ -7577,6 +8047,7 @@ static void ggml_compute_forward_get_rows(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
         case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
@@ -7902,11 +8373,11 @@ static void ggml_compute_forward_rope_f16(
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                    const float x0 = ggml_fp16_to_fp32(src[0]);
-                    const float x1 = ggml_fp16_to_fp32(src[1]);
+                    const float x0 = GGML_FP16_TO_FP32(src[0]);
+                    const float x1 = GGML_FP16_TO_FP32(src[1]);
 
-                    dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
-                    dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
+                    dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                    dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             }
         }
@@ -9982,13 +10453,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             struct ggml_tensor * node = cgraph->nodes[i];
 
             switch (node->op) {
+                case GGML_OP_CPY:
                 case GGML_OP_DUP:
                     {
-                        node->n_tasks = 1;
+                        node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+                        if (ggml_is_quantized(node->type)) {
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
+                        }
+
+                        work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_ADD:
                     {
                         node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        if (ggml_is_quantized(node->src0->type)) {
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
+                        }
+
+                        work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_SUB:
                 case GGML_OP_MUL:
@@ -10033,7 +10520,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         size_t cur = 0;
 
                         if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
@@ -10049,8 +10536,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
-                        } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                        } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
@@ -10069,7 +10556,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = n_threads;
                     } break;
-                case GGML_OP_CPY:
                 case GGML_OP_CONT:
                 case GGML_OP_RESHAPE:
                 case GGML_OP_VIEW:
@@ -11277,6 +11763,29 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
     return (n/QK4_1*sizeof(block_q4_1));
 }
 
+size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK4_2 == 0);
+    const int nb = k / QK4_2;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
+
+        quantize_row_q4_2_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK4_2; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK4_2*sizeof(block_q4_2));
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 int ggml_cpu_has_avx(void) {
@@ -11303,6 +11812,22 @@ int ggml_cpu_has_avx512(void) {
 #endif
 }
 
+int ggml_cpu_has_avx512_vbmi(void) {
+#if defined(__AVX512VBMI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vnni(void) {
+#if defined(__AVX512VNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
@@ -11352,7 +11877,15 @@ int ggml_cpu_has_wasm_simd(void) {
 }
 
 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_cublas(void) {
+#if defined(GGML_USE_CUBLAS)
     return 1;
 #else
     return 0;