]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : llama.cpp (CUDA opts, ggml-quants, YARN, etc.) (#601)
authorGeorgi Gerganov <redacted>
Fri, 3 Nov 2023 08:08:17 +0000 (10:08 +0200)
committerGitHub <redacted>
Fri, 3 Nov 2023 08:08:17 +0000 (10:08 +0200)
ggml-ci

include/ggml/ggml.h
scripts/sync-llama.sh
src/CMakeLists.txt
src/ggml-cuda.cu
src/ggml-impl.h
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-quants.c [new file with mode: 0644]
src/ggml-quants.h [new file with mode: 0644]
src/ggml.c
tests/test-quantize-fns.cpp

index 253e4e27f68704ee0a15ad27720254fc8083cc68..e56a8337a507747eae914cd355c08d791c68ba4c 100644 (file)
 #define GGML_MAX_CONTEXTS       64
 #define GGML_MAX_SRC            6
 #define GGML_MAX_NAME           64
-#define GGML_MAX_OP_PARAMS      32
+#define GGML_MAX_OP_PARAMS      64
 #define GGML_DEFAULT_N_THREADS  4
 #define GGML_DEFAULT_GRAPH_SIZE 2048
 #if UINTPTR_MAX == 0xFFFFFFFF
@@ -1330,8 +1330,13 @@ extern "C" {
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
-            float                 freq_scale);
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
 
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
@@ -1341,8 +1346,17 @@ extern "C" {
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
-            float                 freq_scale);
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+    // compute correction dims for YaRN RoPE scaling
+    void ggml_rope_yarn_corr_dims(
+        int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
     // xPos RoPE, in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
@@ -1941,12 +1955,19 @@ extern "C" {
     // quantization
     //
 
+    // TODO: these would probably get removed in favor of the more general ggml_quantize_chunk
     GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
 
+    GGML_API size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
+
     GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
 
     //
index b9b7aed14f1a52da3cf3a05c1c50f9574bbe51dc..b8c649ba058e3e41271ebebc0ae6aa39b01260c9 100755 (executable)
@@ -1,18 +1,24 @@
 #!/bin/bash
 
-cp -rpv ../llama.cpp/ggml.c           src/ggml.c
-cp -rpv ../llama.cpp/ggml-alloc.c     src/ggml-alloc.c
-cp -rpv ../llama.cpp/ggml-backend.c   src/ggml-backend.c
-cp -rpv ../llama.cpp/ggml-cuda.h      src/ggml-cuda.h
-cp -rpv ../llama.cpp/ggml-cuda.cu     src/ggml-cuda.cu
-cp -rpv ../llama.cpp/ggml-opencl.h    src/ggml-opencl.h
-cp -rpv ../llama.cpp/ggml-opencl.cpp  src/ggml-opencl.cpp
-cp -rpv ../llama.cpp/ggml-metal.h     src/ggml-metal.h
-cp -rpv ../llama.cpp/ggml-metal.m     src/ggml-metal.m
-cp -rpv ../llama.cpp/ggml-metal.metal src/ggml-metal.metal
-cp -rpv ../llama.cpp/ggml.h           include/ggml/ggml.h
-cp -rpv ../llama.cpp/ggml-alloc.h     include/ggml/ggml-alloc.h
-cp -rpv ../llama.cpp/ggml-backend.h   include/ggml/ggml-backend.h
+cp -rpv ../llama.cpp/ggml.c              src/ggml.c
+cp -rpv ../llama.cpp/ggml-alloc.c        src/ggml-alloc.c
+cp -rpv ../llama.cpp/ggml-backend-impl.c src/ggml-backend-impl.c
+cp -rpv ../llama.cpp/ggml-backend.c      src/ggml-backend.c
+cp -rpv ../llama.cpp/ggml-cuda.cu        src/ggml-cuda.cu
+cp -rpv ../llama.cpp/ggml-cuda.h         src/ggml-cuda.h
+cp -rpv ../llama.cpp/ggml-impl.h         src/ggml-impl.h
+cp -rpv ../llama.cpp/ggml-metal.h        src/ggml-metal.h
+cp -rpv ../llama.cpp/ggml-metal.m        src/ggml-metal.m
+cp -rpv ../llama.cpp/ggml-metal.metal    src/ggml-metal.metal
+#cp -rpv ../llama.cpp/ggml-mpi.h          src/ggml-mpi.h
+#cp -rpv ../llama.cpp/ggml-mpi.c          src/ggml-mpi.c
+cp -rpv ../llama.cpp/ggml-opencl.cpp     src/ggml-opencl.cpp
+cp -rpv ../llama.cpp/ggml-opencl.h       src/ggml-opencl.h
+cp -rpv ../llama.cpp/ggml-quants.c       src/ggml-quants.c
+cp -rpv ../llama.cpp/ggml-quants.h       src/ggml-quants.h
+cp -rpv ../llama.cpp/ggml.h              include/ggml/ggml.h
+cp -rpv ../llama.cpp/ggml-alloc.h        include/ggml/ggml-alloc.h
+cp -rpv ../llama.cpp/ggml-backend.h      include/ggml/ggml-backend.h
 
 cp -rpv ../llama.cpp/tests/test-opt.cpp           tests/test-opt.cpp
 cp -rpv ../llama.cpp/tests/test-grad0.cpp         tests/test-grad0.cpp
index 192e01eae8f0b87159383700c2873dc30188dc7b..d65a58e548c183e00816fdb23585e2ac8a748e9f 100644 (file)
@@ -253,6 +253,7 @@ add_library(${TARGET}
     ggml.c
     ggml-alloc.c
     ggml-backend.c
+    ggml-quants.c
     ggml-impl.h
     ggml-backend-impl.h
     ../include/ggml/ggml.h
index 92eb547f6114aa68e0c8f7907c3660e54a4bea8a..adc34aab76f85e530b5c028fc23fc660d5d96fd3 100644 (file)
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
 
+// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
+// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
+// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
+// -  7B quantum model: +100-200 MB
+// - 13B quantum model: +200-400 MB
+//
+//#define GGML_CUDA_FORCE_MMQ
+
+// TODO: improve this to be correct for more hardware
+//       for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
+//       probably other such cases, and not sure what happens on AMD hardware
+#if !defined(GGML_CUDA_FORCE_MMQ)
+#define CUDA_USE_TENSOR_CORES
+#endif
+
+// max batch size to use MMQ kernels when tensor cores are available
+#define MMQ_MAX_BATCH_SIZE 32
+
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
@@ -164,11 +182,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do {                                                                                \
         cudaError_t err_ = (err);                                                       \
         if (err_ != cudaSuccess) {                                                      \
-            int id;                                                                     \
-            cudaGetDevice(&id);                                                         \
+            int dev_id;                                                                     \
+            cudaGetDevice(&dev_id);                                                         \
             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_));                                              \
-            fprintf(stderr, "current device: %d\n", id);                                \
+            fprintf(stderr, "current device: %d\n", dev_id);                                \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
@@ -178,11 +196,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do {                                                                                \
         cublasStatus_t err_ = (err);                                                    \
         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
-            int id;                                                                     \
-            cudaGetDevice(&id);                                                         \
+            int dev_id;                                                                     \
+            cudaGetDevice(&dev_id);                                                         \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \
                     err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \
-            fprintf(stderr, "current device: %d\n", id);                                \
+            fprintf(stderr, "current device: %d\n", dev_id);                                \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
@@ -448,6 +466,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 
 #define MAX_STREAMS 8
 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
+static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -471,7 +490,6 @@ static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
-static bool g_mul_mat_q = true;
 
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
@@ -497,6 +515,15 @@ static __global__ void add_f16_f32_f16(const half * x, const float * y, half * d
     dst[i] = __hadd(x[i], __float2half(y[i]));
 }
 
+static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = __half2float(x[i]) + y[i];
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -3555,9 +3582,15 @@ static __device__ __forceinline__ void mul_mat_q(
 #define  MMQ_X_Q4_0_RDNA1  64
 #define  MMQ_Y_Q4_0_RDNA1  64
 #define NWARPS_Q4_0_RDNA1  8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define  MMQ_X_Q4_0_AMPERE 4
+#define  MMQ_Y_Q4_0_AMPERE 32
+#define NWARPS_Q4_0_AMPERE 4
+#else
 #define  MMQ_X_Q4_0_AMPERE 64
 #define  MMQ_Y_Q4_0_AMPERE 128
 #define NWARPS_Q4_0_AMPERE 4
+#endif
 #define  MMQ_X_Q4_0_PASCAL 64
 #define  MMQ_Y_Q4_0_PASCAL 64
 #define NWARPS_Q4_0_PASCAL 8
@@ -3616,9 +3649,15 @@ template <bool need_check> static __global__ void
 #define  MMQ_X_Q4_1_RDNA1  64
 #define  MMQ_Y_Q4_1_RDNA1  64
 #define NWARPS_Q4_1_RDNA1  8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define  MMQ_X_Q4_1_AMPERE 4
+#define  MMQ_Y_Q4_1_AMPERE 32
+#define NWARPS_Q4_1_AMPERE 4
+#else
 #define  MMQ_X_Q4_1_AMPERE 64
 #define  MMQ_Y_Q4_1_AMPERE 128
 #define NWARPS_Q4_1_AMPERE 4
+#endif
 #define  MMQ_X_Q4_1_PASCAL 64
 #define  MMQ_Y_Q4_1_PASCAL 64
 #define NWARPS_Q4_1_PASCAL 8
@@ -3679,9 +3718,15 @@ template <bool need_check> static __global__ void
 #define  MMQ_X_Q5_0_RDNA1  64
 #define  MMQ_Y_Q5_0_RDNA1  64
 #define NWARPS_Q5_0_RDNA1  8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define  MMQ_X_Q5_0_AMPERE 4
+#define  MMQ_Y_Q5_0_AMPERE 32
+#define NWARPS_Q5_0_AMPERE 4
+#else
 #define  MMQ_X_Q5_0_AMPERE 128
 #define  MMQ_Y_Q5_0_AMPERE 64
 #define NWARPS_Q5_0_AMPERE 4
+#endif
 #define  MMQ_X_Q5_0_PASCAL 64
 #define  MMQ_Y_Q5_0_PASCAL 64
 #define NWARPS_Q5_0_PASCAL 8
@@ -3740,9 +3785,15 @@ template <bool need_check> static __global__ void
 #define  MMQ_X_Q5_1_RDNA1  64
 #define  MMQ_Y_Q5_1_RDNA1  64
 #define NWARPS_Q5_1_RDNA1  8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define  MMQ_X_Q5_1_AMPERE 4
+#define  MMQ_Y_Q5_1_AMPERE 32
+#define NWARPS_Q5_1_AMPERE 4
+#else
 #define  MMQ_X_Q5_1_AMPERE 128
 #define  MMQ_Y_Q5_1_AMPERE 64
 #define NWARPS_Q5_1_AMPERE 4
+#endif
 #define  MMQ_X_Q5_1_PASCAL 64
 #define  MMQ_Y_Q5_1_PASCAL 64
 #define NWARPS_Q5_1_PASCAL 8
@@ -3801,9 +3852,15 @@ mul_mat_q5_1(
 #define  MMQ_X_Q8_0_RDNA1  64
 #define  MMQ_Y_Q8_0_RDNA1  64
 #define NWARPS_Q8_0_RDNA1  8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define  MMQ_X_Q8_0_AMPERE 4
+#define  MMQ_Y_Q8_0_AMPERE 32
+#define NWARPS_Q8_0_AMPERE 4
+#else
 #define  MMQ_X_Q8_0_AMPERE 128
 #define  MMQ_Y_Q8_0_AMPERE 64
 #define NWARPS_Q8_0_AMPERE 4
+#endif
 #define  MMQ_X_Q8_0_PASCAL 64
 #define  MMQ_Y_Q8_0_PASCAL 64
 #define NWARPS_Q8_0_PASCAL 8
@@ -3862,9 +3919,15 @@ template <bool need_check> static __global__ void
 #define  MMQ_X_Q2_K_RDNA1  128
 #define  MMQ_Y_Q2_K_RDNA1  32
 #define NWARPS_Q2_K_RDNA1  8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define  MMQ_X_Q2_K_AMPERE 4
+#define  MMQ_Y_Q2_K_AMPERE 32
+#define NWARPS_Q2_K_AMPERE 4
+#else
 #define  MMQ_X_Q2_K_AMPERE 64
 #define  MMQ_Y_Q2_K_AMPERE 128
 #define NWARPS_Q2_K_AMPERE 4
+#endif
 #define  MMQ_X_Q2_K_PASCAL 64
 #define  MMQ_Y_Q2_K_PASCAL 64
 #define NWARPS_Q2_K_PASCAL 8
@@ -3923,9 +3986,15 @@ mul_mat_q2_K(
 #define  MMQ_X_Q3_K_RDNA1  32
 #define  MMQ_Y_Q3_K_RDNA1  128
 #define NWARPS_Q3_K_RDNA1  8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define  MMQ_X_Q3_K_AMPERE 4
+#define  MMQ_Y_Q3_K_AMPERE 32
+#define NWARPS_Q3_K_AMPERE 4
+#else
 #define  MMQ_X_Q3_K_AMPERE 128
 #define  MMQ_Y_Q3_K_AMPERE 128
 #define NWARPS_Q3_K_AMPERE 4
+#endif
 #define  MMQ_X_Q3_K_PASCAL 64
 #define  MMQ_Y_Q3_K_PASCAL 64
 #define NWARPS_Q3_K_PASCAL 8
@@ -3986,9 +4055,15 @@ template <bool need_check> static __global__ void
 #define  MMQ_X_Q4_K_RDNA1  32
 #define  MMQ_Y_Q4_K_RDNA1  64
 #define NWARPS_Q4_K_RDNA1  8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define  MMQ_X_Q4_K_AMPERE 4
+#define  MMQ_Y_Q4_K_AMPERE 32
+#define NWARPS_Q4_K_AMPERE 4
+#else
 #define  MMQ_X_Q4_K_AMPERE 64
 #define  MMQ_Y_Q4_K_AMPERE 128
 #define NWARPS_Q4_K_AMPERE 4
+#endif
 #define  MMQ_X_Q4_K_PASCAL 64
 #define  MMQ_Y_Q4_K_PASCAL 64
 #define NWARPS_Q4_K_PASCAL 8
@@ -4049,9 +4124,15 @@ template <bool need_check> static __global__ void
 #define  MMQ_X_Q5_K_RDNA1  32
 #define  MMQ_Y_Q5_K_RDNA1  64
 #define NWARPS_Q5_K_RDNA1  8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define  MMQ_X_Q5_K_AMPERE 4
+#define  MMQ_Y_Q5_K_AMPERE 32
+#define NWARPS_Q5_K_AMPERE 4
+#else
 #define  MMQ_X_Q5_K_AMPERE 64
 #define  MMQ_Y_Q5_K_AMPERE 128
 #define NWARPS_Q5_K_AMPERE 4
+#endif
 #define  MMQ_X_Q5_K_PASCAL 64
 #define  MMQ_Y_Q5_K_PASCAL 64
 #define NWARPS_Q5_K_PASCAL 8
@@ -4110,9 +4191,15 @@ mul_mat_q5_K(
 #define  MMQ_X_Q6_K_RDNA1  32
 #define  MMQ_Y_Q6_K_RDNA1  64
 #define NWARPS_Q6_K_RDNA1  8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define  MMQ_X_Q6_K_AMPERE 4
+#define  MMQ_Y_Q6_K_AMPERE 32
+#define NWARPS_Q6_K_AMPERE 4
+#else
 #define  MMQ_X_Q6_K_AMPERE 64
 #define  MMQ_Y_Q6_K_AMPERE 64
 #define NWARPS_Q6_K_AMPERE 4
+#endif
 #define  MMQ_X_Q6_K_PASCAL 64
 #define  MMQ_Y_Q6_K_PASCAL 64
 #define NWARPS_Q6_K_PASCAL 8
@@ -4408,11 +4495,41 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-// rope == RoPE == rotary positional embedding
+static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
 
+struct rope_corr_dims {
+    float v[4];
+};
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static __device__ void rope_yarn(
+    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+// rope == RoPE == rotary positional embedding
 template<typename T, bool has_pos>
-static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                            const int p_delta_rows, const float theta_scale) {
+static __global__ void rope(
+    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims
+) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
@@ -4424,10 +4541,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
     const int i2 = row/p_delta_rows;
 
     const int p = has_pos ? pos[i2] : 0;
-    const float p0 = p*freq_scale;
-    const float theta = p0*powf(theta_scale, col/2);
-    const float sin_theta = sinf(theta);
-    const float cos_theta = cosf(theta);
+    const float theta_base = p*powf(freq_base, -float(col)/ncols);
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + 1];
@@ -4437,8 +4554,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
 }
 
 template<typename T, bool has_pos>
-static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                 const int p_delta_rows, const float theta_scale) {
+static __global__ void rope_neox(
+    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims
+) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
@@ -4449,11 +4568,14 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
     const int i = row*ncols + col/2;
     const int i2 = row/p_delta_rows;
 
+    // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
+    const float cur_rot = -float(col)/ncols;
+
     const int p = has_pos ? pos[i2] : 0;
-    const float p0 = p*freq_scale;
-    const float theta = p0*powf(theta_scale, col/2);
-    const float sin_theta = sinf(theta);
-    const float cos_theta = cosf(theta);
+    const float theta_base = p*powf(freq_base, cur_rot);
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + ncols/2];
@@ -4462,8 +4584,10 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
     dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                    const int p_delta_rows, const float theta_scale, const int n_ctx) {
+static __global__ void rope_glm_f32(
+    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    int n_ctx
+) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
 
@@ -4475,7 +4599,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     const int i = row*ncols + col;
     const int i2 = row/p_delta_rows;
 
-    const float col_theta_scale = powf(theta_scale, col);
+    const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
      // FIXME: this is likely wrong
     const int p = pos != nullptr ? pos[i2] : 0;
 
@@ -4617,6 +4741,11 @@ static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, co
     add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
+static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -5494,40 +5623,54 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const
 }
 
 template<typename T>
-static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
-                          const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+static void rope_cuda(
+    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
     if (pos == nullptr) {
-        rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
     } else {
-        rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
     }
 }
 
 template<typename T>
-static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
-                          const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+static void rope_neox_cuda(
+    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
     if (pos == nullptr) {
-        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
     } else {
-        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
     }
 }
 
-static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
-                              const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
+static void rope_glm_f32_cuda(
+    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, int n_ctx, cudaStream_t stream
+) {
     GGML_ASSERT(ncols % 4 == 0);
     const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
     const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
 }
 
 static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
@@ -5631,6 +5774,16 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
+static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
+    if (g_cudaMemPools[id] == nullptr) {
+        return ggml_cuda_pool_malloc(size, actual_size);
+    }
+    void *ptr;
+    CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
+    *actual_size = size;
+    return ptr;
+}
+
 static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
@@ -5649,6 +5802,13 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 }
 
 
+static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
+    if (g_cudaMemPools[id] == nullptr) {
+        return ggml_cuda_pool_free(ptr, actual_size);
+    }
+    CUDA_CHECK(cudaFreeAsync(ptr, stream));
+}
+
 void ggml_init_cublas() {
     static bool initialized = false;
 
@@ -5664,6 +5824,16 @@ void ggml_init_cublas() {
         CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
+#if defined(GGML_CUDA_FORCE_MMQ)
+        fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ:   yes\n", __func__);
+#else
+        fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ:   no\n", __func__);
+#endif
+#if defined(CUDA_USE_TENSOR_CORES)
+        fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
+#else
+        fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
+#endif
         fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
         for (int id = 0; id < g_device_count; ++id) {
             cudaDeviceProp prop;
@@ -5693,6 +5863,13 @@ void ggml_init_cublas() {
             // create cublas handle
             CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
             CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
+
+            // configure memory pool
+            cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
+            if (err == cudaSuccess) {
+                size_t treshold = UINT64_MAX;
+                CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
+            }
         }
 
         // configure logging to stdout
@@ -5910,7 +6087,10 @@ inline void ggml_cuda_op_add(
         add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
         add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream);
     } else {
+        fprintf(stderr, "src0->type: %d  dst->type: %d\n", src0->type, dst->type);
         GGML_ASSERT(false);
     }
 
@@ -6255,16 +6435,15 @@ inline void ggml_cuda_op_mul_mat_cublas(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
-    GGML_ASSERT(src0_dd_i != nullptr);
+    GGML_ASSERT(src0_dd_i  != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_dd_i != nullptr);
-
+    GGML_ASSERT(dst_dd_i   != nullptr);
 
     const int64_t ne00 = src0->ne[0];
-
     const int64_t ne10 = src1->ne[0];
 
     const int64_t ne0 = dst->ne[0];
+
     const int64_t row_diff = row_high - row_low;
 
     int id;
@@ -6284,7 +6463,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = row_diff*ne00;
-            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
             to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6295,13 +6474,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = src1_ncols*ne10;
-            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
+            src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
-
-        size_t dst_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+        size_t dst_f16_as = 0;
+        half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -6319,14 +6497,15 @@ inline void ggml_cuda_op_mul_mat_cublas(
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
 
-        ggml_cuda_pool_free(dst_f16, dst_as);
+        if (dst_f16_as != 0) {
+            ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
+        }
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_as_f16, src0_as);
+            ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
         }
-
         if (src1_as != 0) {
-            ggml_cuda_pool_free(src1_as_f16, src1_as);
+            ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
         }
     }
     else {
@@ -6336,7 +6515,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         if (src0->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
             GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
+            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
         }
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6349,11 +6528,11 @@ inline void ggml_cuda_op_mul_mat_cublas(
             cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                     row_diff, src1_ncols, ne10,
                     &alpha, src0_ddf_i, ne00,
-                            src1_ddf_i,  ne10,
+                            src1_ddf_i, ne10,
                     &beta,  dst_dd_i,   ldc));
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
+            ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
         }
     }
 
@@ -6375,17 +6554,20 @@ inline void ggml_cuda_op_rope(
     const int64_t ne2 = dst->ne[2];
     const int64_t nrows = ggml_nrows(src0);
 
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
-    // RoPE alteration for extended context
-
-    float freq_base, freq_scale;
-    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    //const int n_past      = ((int32_t *) dst->op_params)[0];
+    const int n_dims      = ((int32_t *) dst->op_params)[1];
+    const int mode        = ((int32_t *) dst->op_params)[2];
+    const int n_ctx       = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx  = ((int32_t *) dst->op_params)[4];
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    // RoPE alteration for extended context
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
     const int32_t * pos = nullptr;
     if ((mode & 1) == 0) {
@@ -6397,24 +6579,39 @@ inline void ggml_cuda_op_rope(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    rope_corr_dims corr_dims;
+    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+
     // compute
     if (is_glm) {
         GGML_ASSERT(false);
-        rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
+        rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
     } else if (is_neox) {
         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
         if (src0->type == GGML_TYPE_F32) {
-            rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_neox_cuda(
+                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_neox_cuda(
+                (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
         } else {
             GGML_ASSERT(false);
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
-            rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_cuda(
+                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_cuda(
+                (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
         } else {
             GGML_ASSERT(false);
         }
@@ -6525,8 +6722,10 @@ inline void ggml_cuda_op_clamp(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const float min = ((float *) dst->op_params)[0];
-    const float max = ((float *) dst->op_params)[1];
+    float min;
+    float max;
+    memcpy(&min, dst->op_params, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 
     clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
     CUDA_CHECK(cudaGetLastError());
@@ -6756,21 +6955,22 @@ static void ggml_cuda_op_mul_mat(
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
             const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
-            src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
         }
 
         if (src1_on_device && src1_is_contiguous) {
             src1_ddf[id] = (float *) src1_extra->data_device[id];
         } else {
-            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
         }
 
         if (convert_src1_to_q8_1) {
-            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
+            const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+            src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
 
             if (src1_on_device && src1_is_contiguous) {
                 quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
-                CUDA_CHECK(cudaGetLastError());
+                // CUDA_CHECK(cudaGetLastError());
             }
         }
 
@@ -6778,7 +6978,7 @@ static void ggml_cuda_op_mul_mat(
             dst_dd[id] = (float *) dst_extra->data_device[id];
         } else {
             const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
-            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
+            dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id,  stream);
         }
     }
 
@@ -6904,24 +7104,6 @@ static void ggml_cuda_op_mul_mat(
         }
     }
 
-    for (int64_t id = 0; id < g_device_count; ++id) {
-        CUDA_CHECK(ggml_cuda_set_device(id));
-
-        // free buffers again when done
-        if (src0_as[id] > 0) {
-            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
-        }
-        if (src1_asf[id] > 0) {
-            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
-        }
-        if (src1_asq[id] > 0) {
-            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
-        }
-        if (dst_as[id] > 0) {
-            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
-        }
-    }
-
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@@ -6939,6 +7121,21 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
+
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
+        }
+        if (src1_asq[id] > 0) {
+            ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
+        }
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
+        }
+    }
 }
 
 static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7050,9 +7247,34 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
-static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+__global__ void k_compute_batched_ptrs(
+        const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
+        const void ** ptrs_src, void ** ptrs_dst,
+        int ne12, int ne13,
+        int ne23,
+        int nb02, int nb03,
+        int nb12, int nb13,
+        int nb2, int nb3,
+        int r2, int r3) {
+    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    int i03 = i13 / r3;
+    int i02 = i12 / r2;
+
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02   + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
+}
+
+static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
+
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -7100,11 +7322,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     GGML_ASSERT(to_fp16_cuda != nullptr);
 
     size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -7150,71 +7372,76 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
-        // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
         const int ne23 = ne12*ne13;
 
-        // TODO: avoid this alloc
-        void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
-
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
-                ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
-                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
-            }
-        }
-
-        // allocate device memory for pointers
-        void ** ptrs_as = nullptr;
-        CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
+        const void ** ptrs_src = nullptr;
+              void ** ptrs_dst = nullptr;
 
-        // TODO: this does not work for some reason -- not sure why?
-        //size_t ptrs_s = 0;
-        //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+        size_t ptrs_src_s = 0;
+        size_t ptrs_dst_s = 0;
 
-        // copy pointers to device
-        CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
-
-        free(ptrs);
+        ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
+        ptrs_dst = (      void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
 
+        dim3 block_dims(ne13, ne12);
+        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                src0_as_f16, src1_as_f16, dst_f16,
+                ptrs_src, ptrs_dst,
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                nb12, nb13,
+                dst->nb[2], dst->nb[3],
+                r2, r3);
+        CUDA_CHECK(cudaGetLastError());
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                            (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                            (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
                 ne23,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-        // free device memory for pointers
-        CUDA_CHECK(cudaFree(ptrs_as));
-        //ggml_cuda_pool_free(ptrs_as, ptrs_s);
+        if (ptrs_src_s != 0) {
+            ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
+        }
+        if (ptrs_dst_s != 0) {
+            ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
+        }
     }
 #endif
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
+    if (src1_as != 0) {
+        ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
+    }
+    if (dst_as != 0) {
+        ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
+    }
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
-        src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
+    const bool all_on_device =
+        (src0->backend == GGML_BACKEND_GPU) &&
+        (src1->backend == GGML_BACKEND_GPU) &&
+        ( dst->backend == GGML_BACKEND_GPU);
 
     int64_t min_compute_capability = INT_MAX;
     for (int64_t id = 0; id < g_device_count; ++id) {
-        if (min_compute_capability > g_compute_capabilities[id]
-                && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+        if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
             min_compute_capability = g_compute_capabilities[id];
         }
     }
 
+#ifdef CUDA_USE_TENSOR_CORES
+    const bool use_tensor_cores = true;
+#else
+    const bool use_tensor_cores = false;
+#endif
+
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -7223,19 +7450,19 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // KQ
+    if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // KQ single-batch
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // KQV
+    } else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // KQV single-batch
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
-    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+        // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
-
 #ifdef GGML_CUDA_FORCE_DMMV
             const bool use_mul_mat_vec_q = false;
 #else
@@ -7248,7 +7475,15 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
             }
         } else {
-            if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
+            bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+
+            // when tensor cores are available, use them for large batch size
+            // ref: https://github.com/ggerganov/llama.cpp/pull/3776
+            if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1->ne[1] > MMQ_MAX_BATCH_SIZE) {
+                use_mul_mat_q = false;
+            }
+
+            if (use_mul_mat_q) {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
             } else {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
@@ -7602,10 +7837,6 @@ void ggml_cuda_set_main_device(const int main_device) {
     }
 }
 
-void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
-    g_mul_mat_q = mul_mat_q;
-}
-
 void ggml_cuda_set_scratch_size(const size_t scratch_size) {
     // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously
     // it still won't always work as expected, but it's better than nothing
@@ -7852,6 +8083,8 @@ static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
 };
 
 static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) {
+    ggml_cuda_set_device(g_main_device);
+
     ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda;
 
     size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
index e5a2d9d108d6a30295100eabed4e1ea185c14864..d88f261449f058d3717ab7de619b14cb8834fa51 100644 (file)
 
 // GGML internal header
 
+#include <assert.h>
 #include <stddef.h>
 #include <stdbool.h>
+#include <string.h> // memcpy
+#include <math.h>   // fabsf
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 
+// static_assert should be a #define, but if it's not,
+// fall back to the _Static_assert C11 keyword.
+// if C99 - static_assert is noop
+// ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+#endif
+
+// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __FMA__
+#define __FMA__
+#endif
+#ifndef __F16C__
+#define __F16C__
+#endif
+#ifndef __SSE3__
+#define __SSE3__
+#endif
+#endif
+
+#undef MIN
+#undef MAX
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#if defined(__ARM_NEON) && !defined(_MSC_VER)
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
+#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
+
+#define GGML_FP16_TO_FP32(x) ((float) (x))
+#define GGML_FP32_TO_FP16(x) (x)
+
+#else
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <intrin.h>
+#else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
+#include <immintrin.h>
+#endif
+#endif
+#endif
+#endif
+#endif
+
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
+#ifdef __F16C__
+
+#ifdef _MSC_VER
+#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
+#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
+#else
+#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+#endif
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+/* the inline asm below is about 12% faster than the lookup method */
+#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    register float f;
+    register double d;
+    __asm__(
+        "mtfprd %0,%2\n"
+        "xscvhpdp %0,%0\n"
+        "frsp %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=f"(f):
+        /* in */   "r"(h));
+    return f;
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+    register double d;
+    register ggml_fp16_t r;
+    __asm__( /* xscvdphp can work on double or single precision */
+        "xscvdphp %0,%2\n"
+        "mffprd %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=r"(r):
+        /* in */   "f"(f));
+    return r;
+}
+
+#else
+
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
+
+static inline float fp32_from_bits(uint32_t w) {
+    union {
+        uint32_t as_bits;
+        float as_value;
+    } fp32;
+    fp32.as_bits = w;
+    return fp32.as_value;
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+    union {
+        float as_value;
+        uint32_t as_bits;
+    } fp32;
+    fp32.as_value = f;
+    return fp32.as_bits;
+}
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    const uint32_t w = (uint32_t) h << 16;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    const uint32_t two_w = w + w;
+
+    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float exp_scale = 0x1.0p-112f;
+#else
+    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+    const uint32_t magic_mask = UINT32_C(126) << 23;
+    const float magic_bias = 0.5f;
+    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+    const uint32_t result = sign |
+        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+    return fp32_from_bits(result);
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float scale_to_inf = 0x1.0p+112f;
+    const float scale_to_zero = 0x1.0p-110f;
+#else
+    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+    const uint32_t w = fp32_to_bits(f);
+    const uint32_t shl1_w = w + w;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+    if (bias < UINT32_C(0x71000000)) {
+        bias = UINT32_C(0x71000000);
+    }
+
+    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+    const uint32_t bits = fp32_to_bits(base);
+    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+    const uint32_t nonsign = exp_bits + mantissa_bits;
+    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+}
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+
+#endif // __F16C__
+
+#endif // __ARM_NEON
+
+// precomputed f32 table for f16 (256 KB)
+// defined in ggml.c, initialized in ggml_init()
+extern float ggml_table_f32_f16[1 << 16];
+
+// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
+// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
+// This is also true for POWER9.
+#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
+
+inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
+    uint16_t s;
+    memcpy(&s, &f, sizeof(uint16_t));
+    return ggml_table_f32_f16[s];
+}
+
+#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+#endif
 
 #define GGML_HASHTABLE_FULL ((size_t)-1)
 #define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
index 12f66237474bef71840d36c9b110c6392ed7337e..9136a7cf6a1fc9b6cbcb2daf3cd2551e82427513 100644 (file)
@@ -211,6 +211,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
             GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
 
             NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+            if (sourcePath == nil) {
+                GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
+                sourcePath = @"ggml-metal.metal";
+            }
             GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
             NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
             if (error) {
@@ -235,14 +239,17 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     // load kernels
     {
         NSError * error = nil;
-#define GGML_METAL_ADD_KERNEL(name) \
-        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
-        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
+
+        /*
         GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
                 (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
                 (int) ctx->pipeline_##name.threadExecutionWidth); \
+        */
+#define GGML_METAL_ADD_KERNEL(name) \
+        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
         if (error) { \
-          GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+            GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
         }
 
@@ -745,9 +752,8 @@ void ggml_metal_graph_compute(
                     case GGML_OP_TRANSPOSE:
                     case GGML_OP_PERMUTE:
                         {
-                            // noop
-                            continue;
-                        } break;
+                            // noop -> next node
+                        } continue;
                     default:
                         {
                         } break;
@@ -1002,11 +1008,15 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_SOFT_MAX:
                         {
-                            const int nth = MIN(32, ne00);
+                            int nth = 32; // SIMD width
 
                             if (ne00%4 == 0) {
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
                             } else {
+                                do {
+                                    nth *= 2;
+                                } while (nth <= ne00 && nth <= 1024);
+                                nth /= 2;
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max];
                             }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1014,8 +1024,9 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
 
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_DIAG_MASK_INF:
                         {
@@ -1396,14 +1407,18 @@ void ggml_metal_graph_compute(
 
                             const int nth = MIN(1024, ne00);
 
-                            const int n_past = ((int32_t *) dst->op_params)[0];
-                            const int n_dims = ((int32_t *) dst->op_params)[1];
-                            const int mode   = ((int32_t *) dst->op_params)[2];
+                            const int n_past     = ((int32_t *) dst->op_params)[0];
+                            const int n_dims     = ((int32_t *) dst->op_params)[1];
+                            const int mode       = ((int32_t *) dst->op_params)[2];
+                            const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
 
-                            float freq_base;
-                            float freq_scale;
-                            memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-                            memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+                            float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+                            memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+                            memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+                            memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+                            memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+                            memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+                            memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
                             switch (src0->type) {
                                 case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
@@ -1411,30 +1426,35 @@ void ggml_metal_graph_compute(
                                 default: GGML_ASSERT(false);
                             };
 
-                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1        atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:2];
-                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:6];
-                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:10];
-                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:14];
-                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:17];
-                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:18];
-                            [encoder setBytes:&n_past  length:sizeof(     int) atIndex:19];
-                            [encoder setBytes:&n_dims  length:sizeof(     int) atIndex:20];
-                            [encoder setBytes:&mode    length:sizeof(     int) atIndex:21];
-                            [encoder setBytes:&freq_base  length:sizeof(float) atIndex:22];
-                            [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
+                            [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
+                            [encoder setBuffer:id_dst      offset:offs_dst         atIndex:2];
+                            [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:6];
+                            [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:10];
+                            [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:14];
+                            [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:18];
+                            [encoder setBytes:&n_past      length:sizeof(     int) atIndex:19];
+                            [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:20];
+                            [encoder setBytes:&mode        length:sizeof(     int) atIndex:21];
+                            [encoder setBytes:&n_orig_ctx  length:sizeof(     int) atIndex:22];
+                            [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23];
+                            [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24];
+                            [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25];
+                            [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:26];
+                            [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:27];
+                            [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:28];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
index f4b460564453c5b3cfb1e0b545fea0d5e2c6182e..7c35f23a7612fd75362457b3fc9d137cd37e0bfa 100644 (file)
@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
         lmax = MAX(lmax, psrc0[i00]);
     }
-    const float max = simd_max(lmax);
+
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+       if (tpitg < i) {
+           buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+       }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
         lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
-        // whish to compute it twice.
+        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = simd_sum(lsum);
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+       if (tpitg < i) {
+           buf[tpitg] += buf[tpitg + i];
+       }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
+
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         pdst[i00] /= sum;
     }
 }
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
     device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
         lmax4 = fmax(lmax4, psrc4[i00]);
     }
-    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    const float max = simd_max(lmax);
+    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+       if (tpitg < i) {
+           buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+       }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         const float4 exp_psrc4 = exp(psrc4[i00] - max);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
-    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    const float sum = simd_sum(lsum);
+    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+       if (tpitg < i) {
+           buf[tpitg] += buf[tpitg + i];
+       }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
 
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         pdst4[i00] /= sum;
     }
 }
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
-     }
+    }
 }
 
 kernel void kernel_diag_mask_inf_8(
@@ -988,6 +1061,45 @@ kernel void kernel_alibi_f32(
     }
 }
 
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    thread float * cos_theta, thread float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    *cos_theta = cos(theta) * mscale;
+    *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
 typedef void (rope_t)(
         device const    void * src0,
         device const int32_t * src1,
@@ -1011,8 +1123,13 @@ typedef void (rope_t)(
         constant         int & n_past,
         constant         int & n_dims,
         constant         int & mode,
+        constant         int & n_orig_ctx,
         constant       float & freq_base,
         constant       float & freq_scale,
+        constant       float & ext_factor,
+        constant       float & attn_factor,
+        constant       float & beta_fast,
+        constant       float & beta_slow,
         uint  tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
         uint3 tgpig[[threadgroup_position_in_grid]]);
@@ -1041,8 +1158,13 @@ kernel void kernel_rope(
         constant         int & n_past,
         constant         int & n_dims,
         constant         int & mode,
+        constant         int & n_orig_ctx,
         constant       float & freq_base,
         constant       float & freq_scale,
+        constant       float & ext_factor,
+        constant       float & attn_factor,
+        constant       float & beta_fast,
+        constant       float & beta_slow,
         uint  tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
         uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -1052,19 +1174,22 @@ kernel void kernel_rope(
 
     const bool is_neox = mode & 2;
 
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
     device const int32_t * pos = src1;
 
     const int64_t p = pos[i2];
 
-    const float theta_0 = freq_scale * (float)p;
+    const float theta_0 = (float)p;
     const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
 
             const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
-            const float cos_theta = cos(theta);
-            const float sin_theta = sin(theta);
+            float cos_theta, sin_theta;
+            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
             device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -1079,9 +1204,12 @@ kernel void kernel_rope(
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
             for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
 
-                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
-                const float cos_theta = cos(theta);
-                const float sin_theta = sin(theta);
+                // simplified from `(ib * n_dims + ic) * inv_ndims`
+                const float cur_rot = inv_ndims*ic - ib;
+
+                const float theta = theta_0 * pow(freq_base, cur_rot);
+                float cos_theta, sin_theta;
+                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
                 const int64_t i0 = ib*n_dims + ic/2;
 
diff --git a/src/ggml-quants.c b/src/ggml-quants.c
new file mode 100644 (file)
index 0000000..740be6d
--- /dev/null
@@ -0,0 +1,7282 @@
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+
+#ifdef __ARM_NEON
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#if !defined(__aarch64__)
+inline static int32_t vaddvq_s16(int16x8_t v) {
+    return
+        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+    return vcombine_s16(a0, b0);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+#endif
+
+#else
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <intrin.h>
+#else
+#if !defined(__riscv) && !defined(__s390__)
+#include <immintrin.h>
+#endif
+#endif
+#endif
+#endif
+#endif
+
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+// multiply int8_t, add results pairwise twice
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+    // Get absolute values of x vectors
+    const __m128i ax = _mm_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m128i sy = _mm_sign_epi8(y, x);
+    // Perform multiplication and create 16-bit values
+    const __m128i dot = _mm_maddubs_epi16(ax, sy);
+    const __m128i ones = _mm_set1_epi16(1);
+    return _mm_madd_epi16(ones, dot);
+}
+
+#if __AVX__ || __AVX2__ || __AVX512F__
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = _mm256_set_epi64x(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytes = _mm256_or_si256(bytes, bit_mask);
+    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i lowMask = _mm256_set1_epi8( 0xF );
+    return _mm256_and_si256(lowMask, bytes);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+#if __AVXVNNI__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_float(dot);
+#endif
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
+static inline __m128i packNibbles( __m256i bytes )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000
+    bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh
+    return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh
+#else
+    const __m256i lowByte = _mm256_set1_epi16( 0xFF );
+    __m256i high = _mm256_andnot_si256( lowByte, bytes );
+    __m256i low = _mm256_and_si256( lowByte, bytes );
+    high = _mm256_srli_epi16( high, 4 );
+    bytes = _mm256_or_si256( low, high );
+
+    // Compress uint16_t lanes into bytes
+    __m128i r0 = _mm256_castsi256_si128( bytes );
+    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
+    return _mm_packus_epi16( r0, r1 );
+#endif
+}
+#elif defined(__AVX__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytesl = _mm_or_si128(bytesl, bit_mask);
+    bytesh = _mm_or_si128(bytesh, bit_mask);
+    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+    return MM256_SET_M128I(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    // Load 16 bytes from memory
+    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+    __m128i tmph = _mm_srli_epi16(tmpl, 4);
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    tmpl = _mm_and_si128(lowMask, tmpl);
+    tmph = _mm_and_si128(lowMask, tmph);
+    return MM256_SET_M128I(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+    const __m128i ones = _mm_set1_epi16(1);
+    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    const __m128i axl = _mm256_castsi256_si128(ax);
+    const __m128i axh = _mm256_extractf128_si256(ax, 1);
+    const __m128i syl = _mm256_castsi256_si128(sy);
+    const __m128i syh = _mm256_extractf128_si256(sy, 1);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    const __m128i xl = _mm256_castsi256_si128(x);
+    const __m128i xh = _mm256_extractf128_si256(x, 1);
+    const __m128i yl = _mm256_castsi256_si128(y);
+    const __m128i yh = _mm256_extractf128_si256(y, 1);
+    // Get absolute values of x vectors
+    const __m128i axl = _mm_sign_epi8(xl, xl);
+    const __m128i axh = _mm_sign_epi8(xh, xh);
+    // Sign the values of the y vectors
+    const __m128i syl = _mm_sign_epi8(yl, xl);
+    const __m128i syh = _mm_sign_epi8(yh, xh);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m128i lowByte = _mm_set1_epi16( 0xFF );
+    __m128i high = _mm_andnot_si128( lowByte, bytes1 );
+    __m128i low = _mm_and_si128( lowByte, bytes1 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes1 = _mm_or_si128( low, high );
+    high = _mm_andnot_si128( lowByte, bytes2 );
+    low = _mm_and_si128( lowByte, bytes2 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes2 = _mm_or_si128( low, high );
+
+    return _mm_packus_epi16( bytes1, bytes2);
+}
+#endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 =_mm_hadd_ps(a, b);
+    __m128 res_1 =_mm_hadd_ps(c, d);
+    __m128 res =_mm_hadd_ps(res_0, res_1);
+    res =_mm_hadd_ps(res, res);
+    res =_mm_hadd_ps(res, res);
+
+    return _mm_cvtss_f32(res);
+}
+#endif // __AVX__ || __AVX2__ || __AVX512F__
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+
+#if defined(__ARM_NEON)
+
+#if !defined(__aarch64__)
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline static float vmaxvq_f32(float32x4_t v) {
+    return
+        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+    int32x4_t res;
+
+    res[0] = roundf(vgetq_lane_f32(v, 0));
+    res[1] = roundf(vgetq_lane_f32(v, 1));
+    res[2] = roundf(vgetq_lane_f32(v, 2));
+    res[3] = roundf(vgetq_lane_f32(v, 3));
+
+    return res;
+}
+
+#endif
+#endif
+
+#if defined(__ARM_NEON) || defined(__wasm_simd128__)
+#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)
+
+// precomputed tables for expanding 8bits to 8 bytes:
+static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
+static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
+#endif
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
+    static const int qk = QK4_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max  = v;
+            }
+        }
+
+        const float d  = max / -8;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = x[i*qk + 0    + j]*id;
+            const float x1 = x[i*qk + qk/2 + j]*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+            y[i].qs[j]  = xi0;
+            y[i].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q4_0_reference(x, y, k);
+}
+
+void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
+    const int qk = QK4_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d  = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = (x[i*qk + 0    + j] - min)*id;
+            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+            y[i].qs[j]  = xi0;
+            y[i].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q4_1_reference(x, y, k);
+}
+
+void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
+    static const int qk = QK5_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max  = v;
+            }
+        }
+
+        const float d  = max / -16;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        uint32_t qh = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = x[i*qk + 0    + j]*id;
+            const float x1 = x[i*qk + qk/2 + j]*id;
+
+            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+        }
+
+        memcpy(&y[i].qh, &qh, sizeof(qh));
+    }
+}
+
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q5_0_reference(x, y, k);
+}
+
+void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
+    const int qk = QK5_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d  = (max - min) / ((1 << 5) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        uint32_t qh = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = (x[i*qk + 0    + j] - min)*id;
+            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+        }
+
+        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
+    }
+}
+
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q5_1_reference(x, y, k);
+}
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = x[i*QK8_0 + j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = x[i*QK8_0 + j]*id;
+
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+}
+
+void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
+                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );     // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );     // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );     // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+    }
+#else
+    GGML_UNUSED(nb);
+    // scalar
+    quantize_row_q8_0_reference(x, y, k);
+#endif
+}
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+    assert(QK8_1 == 32);
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_1; j++) {
+            const float v = x[i*QK8_1 + j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        int sum = 0;
+
+        for (int j = 0; j < QK8_1/2; ++j) {
+            const float v0 = x[i*QK8_1           + j]*id;
+            const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
+
+            y[i].qs[          j] = roundf(v0);
+            y[i].qs[QK8_1/2 + j] = roundf(v1);
+
+            sum += y[i].qs[          j];
+            sum += y[i].qs[QK8_1/2 + j];
+        }
+
+        y[i].s = sum*d;
+    }
+}
+
+void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        int32x4_t accv = vdupq_n_s32(0);
+
+        for (int j = 0; j < 8; j++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
+        }
+
+        y[i].s = d * vaddvq_s32(accv);
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
+                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        v128_t accv = wasm_i32x4_splat(0);
+
+        for (int j = 0; j < 8; j++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+
+            accv = wasm_i32x4_add(accv, vi);
+        }
+
+        y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
+                      wasm_i32x4_extract_lane(accv, 1) +
+                      wasm_i32x4_extract_lane(accv, 2) +
+                      wasm_i32x4_extract_lane(accv, 3));
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );     // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );     // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );     // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Compute the sum of the quants and set y[i].s
+        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+        y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d  = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+
+        // compute sum for y[i].s
+        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+
+        // set y[i].s
+        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+        y[i].s = sum*d;
+    }
+#else
+    GGML_UNUSED(nb);
+    // scalar
+    quantize_row_q8_1_reference(x, y, k);
+#endif
+}
+
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
+    static const int qk = QK4_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F) - 8;
+            const int x1 = (x[i].qs[j] >>   4) - 8;
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
+    static const int qk = QK4_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F);
+            const int x1 = (x[i].qs[j] >>   4);
+
+            y[i*qk + j + 0   ] = x0*d + m;
+            y[i*qk + j + qk/2] = x1*d + m;
+        }
+    }
+}
+
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
+    static const int qk = QK5_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
+            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
+    static const int qk = QK5_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
+            const int x1 = (x[i].qs[j] >>   4) | xh_1;
+
+            y[i*qk + j + 0   ] = x0*d + m;
+            y[i*qk + j + qk/2] = x1*d + m;
+        }
+    }
+}
+
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) {
+    static const int qk = QK8_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (int j = 0; j < qk; ++j) {
+            y[i*qk + j] = x[i].qs[j]*d;
+        }
+    }
+}
+
+//
+// 2-6 bit quantization in super-blocks
+//
+
+//
+// ===================== Helper functions
+//
+static inline int nearest_int(float fval) {
+    assert(fval <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (amax < 1e-30f) { // all zero
+        for (int i = 0; i < n; ++i) {
+            L[i] = 0;
+        }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (rmse_type == 0) {
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+        }
+        return 1/iscale;
+    }
+    bool return_early = false;
+    if (rmse_type < 0) {
+        rmse_type = -rmse_type;
+        return_early = true;
+    }
+    int weight_type = rmse_type%2;
+    float sumlx = 0;
+    float suml2 = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+        float w = weight_type == 1 ? x[i] * x[i] : 1;
+        sumlx += w*x[i]*l;
+        suml2 += w*l*l;
+    }
+    float scale = sumlx/suml2;
+    if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
+    float best = scale * sumlx;
+    for (int is = -9; is <= 9; ++is) {
+        if (is == 0) {
+            continue;
+        }
+        iscale = -(nmax + 0.1f*is) / max;
+        sumlx = suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            float w = weight_type == 1 ? x[i] * x[i] : 1;
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        if (suml2 > 0 && sumlx*sumlx > best*suml2) {
+            for (int i = 0; i < n; ++i) {
+                int l = nearest_int(iscale * x[i]);
+                L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+            }
+            scale = sumlx/suml2; best = scale*sumlx;
+        }
+    }
+    return scale;
+}
+
+static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (!amax) { // all zero
+        for (int i = 0; i < n; ++i) { L[i] = 0; }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (do_rmse) {
+        float sumlx = 0;
+        float suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            L[i] = l;
+            float w = x[i]*x[i];
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        for (int itry = 0; itry < 5; ++itry) {
+            int n_changed = 0;
+            for (int i = 0; i < n; ++i) {
+                float w = x[i]*x[i];
+                float slx = sumlx - w*x[i]*L[i];
+                if (slx > 0) {
+                    float sl2 = suml2 - w*L[i]*L[i];
+                    int new_l = nearest_int(x[i] * sl2 / slx);
+                    new_l = MAX(-nmax, MIN(nmax-1, new_l));
+                    if (new_l != L[i]) {
+                        slx += w*x[i]*new_l;
+                        sl2 += w*new_l*new_l;
+                        if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
+                            L[i] = new_l; sumlx = slx; suml2 = sl2;
+                            ++n_changed;
+                        }
+                    }
+                }
+            }
+            if (!n_changed) {
+                break;
+            }
+        }
+        for (int i = 0; i < n; ++i) {
+            L[i] += nmax;
+        }
+        return sumlx / suml2;
+    }
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+    }
+    return 1/iscale;
+}
+
+static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+        int ntry, float alpha) {
+    float min = x[0];
+    float max = x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+    }
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = 0;
+        return 0.f;
+    }
+    if (min > 0) min = 0;
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    for (int itry = 0; itry < ntry; ++itry) {
+        float sumlx = 0; int suml2 = 0;
+        bool did_change = false;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            if (l != L[i]) {
+                L[i] = l;
+                did_change = true;
+            }
+            sumlx += (x[i] - min)*l;
+            suml2 += l*l;
+        }
+        scale = sumlx/suml2;
+        float sum = 0;
+        for (int i = 0; i < n; ++i) {
+            sum += x[i] - scale*L[i];
+        }
+        min = alpha*min + (1 - alpha)*sum/n;
+        if (min > 0) min = 0;
+        iscale = 1/scale;
+        if (!did_change) break;
+    }
+    *the_min = -min;
+    return scale;
+}
+
+static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+        float rmin, float rdelta, int nstep, bool use_mad) {
+    float min = x[0];
+    float max = x[0];
+    float sum_w = weights[0];
+    float sum_x = sum_w * x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        float w = weights[i];
+        sum_w += w;
+        sum_x += w * x[i];
+    }
+    if (min > 0) min = 0;
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = -min;
+        return 0.f;
+    }
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mad = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        diff = use_mad ? fabsf(diff) : diff * diff;
+        float w = weights[i];
+        best_mad += w * diff;
+    }
+    if (nstep < 1) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = 0; is <= nstep; ++is) {
+        iscale = (rmin + rdelta*is + nmax)/(max - min);
+        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            float w = weights[i];
+            sum_l += w*l;
+            sum_l2 += w*l*l;
+            sum_xl += w*l*x[i];
+        }
+        float D = sum_w * sum_l2 - sum_l * sum_l;
+        if (D > 0) {
+            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+            if (this_min > 0) {
+                this_min = 0;
+                this_scale = sum_xl / sum_l2;
+            }
+            float mad = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                diff = use_mad ? fabsf(diff) : diff * diff;
+                float w = weights[i];
+                mad += w * diff;
+            }
+            if (mad < best_mad) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mad = mad;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
+
+#if QK_K == 256
+static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
+    if (j < 4) {
+        *d = q[j] & 63; *m = q[j + 4] & 63;
+    } else {
+        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
+#endif
+
+//========================- 2-bit (de)-quantization
+
+void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[16];
+    float   weights[16];
+    float mins[QK_K/16];
+    float scales[QK_K/16];
+
+    const float q4scale = 15.f;
+
+    for (int i = 0; i < nb; i++) {
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
+            scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        if (max_scale > 0) {
+            float iscale = q4scale/max_scale;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int l = nearest_int(iscale*scales[j]);
+                y[i].scales[j] = l;
+            }
+            y[i].d = GGML_FP32_TO_FP16(max_scale/q4scale);
+        } else {
+            for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+        }
+        if (max_min > 0) {
+            float iscale = q4scale/max_min;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int l = nearest_int(iscale*mins[j]);
+                y[i].scales[j] |= (l << 4);
+            }
+            y[i].dmin = GGML_FP32_TO_FP16(max_min/q4scale);
+        } else {
+            y[i].dmin = GGML_FP32_TO_FP16(0.f);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = GGML_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF);
+            if (!d) continue;
+            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4);
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int((x[16*j + ii] + dm)/d);
+                l = MAX(0, MIN(3, l));
+                L[16*j + ii] = l;
+            }
+        }
+
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * q = x[i].qs;
+
+#if QK_K == 256
+        int is = 0;
+        float dl, ml;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                uint8_t sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+#else
+        float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
+        float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
+        float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
+        float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
+        for (int l = 0; l < 16; ++l) {
+            y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
+            y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
+            y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
+            y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
+        }
+        y += QK_K;
+#endif
+    }
+}
+
+void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
+    quantize_row_q2_K_reference(x, vy, k);
+}
+
+size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
+        quantize_row_q2_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q2_K));
+}
+
+//========================= 3-bit (de)-quantization
+
+void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    int8_t L[QK_K];
+    float scales[QK_K / 16];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0;
+        float amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
+            float scale = fabsf(scales[j]);
+            if (scale > amax) {
+                amax = scale; max_scale = scales[j];
+            }
+        }
+
+#if QK_K == 256
+        memset(y[i].scales, 0, 12);
+        if (max_scale) {
+            float iscale = -32.f/max_scale;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int8_t l = nearest_int(iscale*scales[j]);
+                l = MAX(-32, MIN(31, l)) + 32;
+                if (j < 8) {
+                    y[i].scales[j] = l & 0xF;
+                } else {
+                    y[i].scales[j-8] |= ((l & 0xF) << 4);
+                }
+                l >>= 4;
+                y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
+            }
+            y[i].d = GGML_FP32_TO_FP16(1/iscale);
+        } else {
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+        }
+
+        int8_t sc;
+        for (int j = 0; j < QK_K/16; ++j) {
+            sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
+            sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
+            float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-4, MIN(3, l));
+                L[16*j + ii] = l + 4;
+            }
+        }
+#else
+        if (max_scale) {
+            float iscale = -8.f/max_scale;
+            for (int j = 0; j < QK_K/16; j+=2) {
+                int l1 = nearest_int(iscale*scales[j]);
+                l1 = 8 + MAX(-8, MIN(7, l1));
+                int l2 = nearest_int(iscale*scales[j+1]);
+                l2 = 8 + MAX(-8, MIN(7, l2));
+                y[i].scales[j/2] = l1 | (l2 << 4);
+            }
+            y[i].d = GGML_FP32_TO_FP16(1/iscale);
+        } else {
+            for (int j = 0; j < QK_K/16; j+=2) {
+                y[i].scales[j/2] = 0;
+            }
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
+            float d = GGML_FP16_TO_FP32(y[i].d) * (s - 8);
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-4, MIN(3, l));
+                L[16*j + ii] = l + 4;
+            }
+        }
+#endif
+
+        memset(y[i].hmask, 0, QK_K/8);
+        // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
+        int m = 0;
+        uint8_t hm = 1;
+        for (int j = 0; j < QK_K; ++j) {
+            if (L[j] > 3) {
+                y[i].hmask[m] |= hm;
+                L[j] -= 4;
+            }
+            if (++m == QK_K/8) {
+                m = 0; hm <<= 1;
+            }
+        }
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
+
+        x += QK_K;
+    }
+}
+
+#if QK_K == 256
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    uint32_t aux[4];
+    const int8_t * scales = (const int8_t*)aux;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        uint8_t m = 1;
+
+        memcpy(aux, x[i].scales, 12);
+        uint32_t tmp = aux[2];
+        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+
+    }
+}
+#else
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    assert(QK_K == 64);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+
+        for (int l=0; l<8; ++l) {
+            uint8_t h = hm[l];
+            y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
+            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
+            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
+            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
+            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
+            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
+            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
+            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+        }
+        y += QK_K;
+    }
+}
+#endif
+
+void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
+    quantize_row_q3_K_reference(x, vy, k);
+}
+
+size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
+        quantize_row_q3_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q3_K));
+}
+
+// ====================== 4-bit (de)-quantization
+
+void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[32];
+    float   weights[32];
+    float mins[QK_K/32];
+    float scales[QK_K/32];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            float sum_x2 = 0;
+            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+            float av_x = sqrtf(sum_x2/32);
+            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+#if QK_K == 256
+        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = nearest_int(inv_scale*scales[j]);
+            uint8_t lm = nearest_int(inv_min*mins[j]);
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
+        y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+            }
+        }
+#else
+        const float s_factor = 15.f;
+        float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? s_factor/max_min   : 0.f;
+        int d1 = nearest_int(inv_scale*scales[0]);
+        int m1 = nearest_int(inv_min*mins[0]);
+        int d2 = nearest_int(inv_scale*scales[1]);
+        int m2 = nearest_int(inv_min*mins[1]);
+        y[i].scales[0] = d1 | (m1 << 4);
+        y[i].scales[1] = d2 | (m2 << 4);
+        y[i].d[0] = GGML_FP32_TO_FP16(max_scale/s_factor);
+        y[i].d[1] = GGML_FP32_TO_FP16(max_min/s_factor);
+
+        float sumlx = 0;
+        int   suml2 = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            const uint8_t sd = y[i].scales[j] & 0xF;
+            const uint8_t sm = y[i].scales[j] >>  4;
+            const float d = GGML_FP16_TO_FP32(y[i].d[0]) * sd;
+            if (!d) continue;
+            const float m = GGML_FP16_TO_FP32(y[i].d[1]) * sm;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + m)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+                sumlx += (x[32*j + ii] + m)*l*sd;
+                suml2 += l*l*sd*sd;
+            }
+        }
+        if (suml2) {
+            y[i].d[0] = GGML_FP32_TO_FP16(sumlx/suml2);
+        }
+#endif
+        uint8_t * q = y[i].qs;
+        for (int j = 0; j < QK_K; j += 64) {
+            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+            q += 32;
+        }
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint8_t * q = x[i].qs;
+
+#if QK_K == 256
+
+        const float d   = GGML_FP16_TO_FP32(x[i].d);
+        const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+        int is = 0;
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K; j += 64) {
+            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
+            q += 32; is += 2;
+        }
+#else
+        const float dall = GGML_FP16_TO_FP32(x[i].d[0]);
+        const float mall = GGML_FP16_TO_FP32(x[i].d[1]);
+        const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
+        const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
+        for (int l = 0; l < 32; ++l) {
+            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
+            y[l+32] = d2 * (q[l] >>  4) - m2;
+        }
+        y += QK_K;
+#endif
+
+    }
+}
+
+void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_q4_K * restrict y = vy;
+    quantize_row_q4_K_reference(x, y, k);
+}
+
+size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    assert(k % QK_K == 0);
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
+        quantize_row_q4_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q4_K));
+}
+
+// ====================== 5-bit (de)-quantization
+
+void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+#if QK_K == 256
+    uint8_t L[QK_K];
+    float mins[QK_K/32];
+    float scales[QK_K/32];
+    float weights[32];
+    uint8_t Laux[32];
+#else
+    int8_t L[QK_K];
+    float scales[QK_K/16];
+#endif
+
+    for (int i = 0; i < nb; i++) {
+
+#if QK_K == 256
+
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            float sum_x2 = 0;
+            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+            float av_x = sqrtf(sum_x2/32);
+            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = nearest_int(inv_scale*scales[j]);
+            uint8_t lm = nearest_int(inv_min*mins[j]);
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
+        y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(31, l));
+                L[32*j + ii] = l;
+            }
+        }
+
+        uint8_t * restrict qh = y[i].qh;
+        uint8_t * restrict ql = y[i].qs;
+        memset(qh, 0, QK_K/8);
+
+        uint8_t m1 = 1, m2 = 2;
+        for (int n = 0; n < QK_K; n += 64) {
+            for (int j = 0; j < 32; ++j) {
+                int l1 = L[n + j];
+                if (l1 > 15) {
+                    l1 -= 16; qh[j] |= m1;
+                }
+                int l2 = L[n + j + 32];
+                if (l2 > 15) {
+                    l2 -= 16; qh[j] |= m2;
+                }
+                ql[j] = l1 | (l2 << 4);
+            }
+            m1 <<= 2; m2 <<= 2;
+            ql += 32;
+        }
+#else
+        float max_scale = 0, amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
+            float abs_scale = fabsf(scales[j]);
+            if (abs_scale > amax) {
+                amax = abs_scale;
+                max_scale = scales[j];
+            }
+        }
+
+        float iscale = -128.f/max_scale;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int l = nearest_int(iscale*scales[j]);
+            y[i].scales[j] = MAX(-128, MIN(127, l));
+        }
+        y[i].d = GGML_FP32_TO_FP16(1/iscale);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
+            if (!d) continue;
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-16, MIN(15, l));
+                L[16*j + ii] = l + 16;
+            }
+        }
+
+        uint8_t * restrict qh = y[i].qh;
+        uint8_t * restrict ql = y[i].qs;
+        memset(qh, 0, QK_K/8);
+
+        for (int j = 0; j < 32; ++j) {
+            int jm = j%8;
+            int is = j/8;
+            int l1 = L[j];
+            if (l1 > 15) {
+                l1 -= 16; qh[jm] |= (1 << is);
+            }
+            int l2 = L[j + 32];
+            if (l2 > 15) {
+                l2 -= 16; qh[jm] |= (1 << (4 + is));
+            }
+            ql[j] = l1 | (l2 << 4);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint8_t * ql = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+
+#if QK_K == 256
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+        int is = 0;
+        uint8_t sc, m;
+        uint8_t u1 = 1, u2 = 2;
+        for (int j = 0; j < QK_K; j += 64) {
+            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+#else
+        float d = GGML_FP16_TO_FP32(x[i].d);
+        const int8_t * restrict s = x[i].scales;
+        for (int l = 0; l < 8; ++l) {
+            y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * s[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * s[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * s[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * s[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
+        }
+        y += QK_K;
+#endif
+    }
+}
+
+void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_q5_K * restrict y = vy;
+    quantize_row_q5_K_reference(x, y, k);
+}
+
+size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    assert(k % QK_K == 0);
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
+        quantize_row_q5_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q5_K));
+}
+
+// ====================== 6-bit (de)-quantization
+
+void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    int8_t L[QK_K];
+    float   scales[QK_K/16];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0;
+        float max_abs_scale = 0;
+
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+
+            const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
+            scales[ib] = scale;
+
+            const float abs_scale = fabsf(scale);
+            if (abs_scale > max_abs_scale) {
+                max_abs_scale = abs_scale;
+                max_scale = scale;
+            }
+
+        }
+
+        if (!max_abs_scale) {
+            memset(&y[i], 0, sizeof(block_q6_K));
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+            x += QK_K;
+            continue;
+        }
+
+        float iscale = -128.f/max_scale;
+        y[i].d = GGML_FP32_TO_FP16(1/iscale);
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
+        }
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-32, MIN(31, l));
+                L[16*j + ii] = l + 32;
+            }
+        }
+
+        uint8_t * restrict ql = y[i].ql;
+        uint8_t * restrict qh = y[i].qh;
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                const uint8_t q1 = L[j + l +  0] & 0xF;
+                const uint8_t q2 = L[j + l + 32] & 0xF;
+                const uint8_t q3 = L[j + l + 64] & 0xF;
+                const uint8_t q4 = L[j + l + 96] & 0xF;
+                ql[l+ 0] = q1 | (q3 << 4);
+                ql[l+32] = q2 | (q4 << 4);
+                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
+            }
+            ql += 64;
+            qh += 32;
+        }
+#else
+        for (int l = 0; l < 32; ++l) {
+            const uint8_t q1 = L[l +  0] & 0xF;
+            const uint8_t q2 = L[l + 32] & 0xF;
+            ql[l] = q1 | (q2 << 4);
+        }
+        for (int l = 0; l < 16; ++l) {
+            qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict ql = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict sc = x[i].scales;
+
+#if QK_K == 256
+        for (int n = 0; n < QK_K; n += 128) {
+            for (int l = 0; l < 32; ++l) {
+                int is = l/16;
+                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+                y[l +  0] = d * sc[is + 0] * q1;
+                y[l + 32] = d * sc[is + 2] * q2;
+                y[l + 64] = d * sc[is + 4] * q3;
+                y[l + 96] = d * sc[is + 6] * q4;
+            }
+            y  += 128;
+            ql += 64;
+            qh += 32;
+            sc += 8;
+        }
+#else
+        for (int l = 0; l < 16; ++l) {
+            const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            const int8_t q3 = (int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            const int8_t q4 = (int8_t)((ql[l+16]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            y[l+ 0] = d * sc[0] * q1;
+            y[l+16] = d * sc[1] * q2;
+            y[l+32] = d * sc[2] * q3;
+            y[l+48] = d * sc[3] * q4;
+        }
+        y  += 64;
+#endif
+
+    }
+}
+
+void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_q6_K * restrict y = vy;
+    quantize_row_q6_K_reference(x, y, k);
+}
+
+size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK_K == 0);
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
+        quantize_row_q6_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q6_K));
+}
+
+//===================================== Q8_K ==============================================
+
+void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        float max = 0;
+        float amax = 0;
+        for (int j = 0; j < QK_K; ++j) {
+            float ax = fabsf(x[j]);
+            if (ax > amax) {
+                amax = ax; max = x[j];
+            }
+        }
+        if (!amax) {
+            y[i].d = 0;
+            memset(y[i].qs, 0, QK_K);
+            x += QK_K;
+            continue;
+        }
+        const float iscale = -128.f/max;
+        for (int j = 0; j < QK_K; ++j) {
+            int v = nearest_int(iscale*x[j]);
+            y[i].qs[j] = MIN(127, v);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            int sum = 0;
+            for (int ii = 0; ii < 16; ++ii) {
+                sum += y[i].qs[j*16 + ii];
+            }
+            y[i].bsums[j] = sum;
+        }
+        y[i].d = 1/iscale;
+        x += QK_K;
+    }
+}
+
+void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+        for (int j = 0; j < QK_K; ++j) {
+            *y++ = x[i].d * x[i].qs[j];
+        }
+    }
+}
+
+void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q8_K_reference(x, y, k);
+}
+
+//===================================== Dot ptoducts =================================
+
+//
+// Helper functions
+//
+#if __AVX__ || __AVX2__ || __AVX512F__
+
+// shuffles to pick the required scales in dot products
+static inline __m256i get_scale_shuffle_q3k(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m256i get_scale_shuffle_k4(int i) {
+    static const uint8_t k_shuffle[256] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
+        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m128i get_scale_shuffle(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
+         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
+         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
+         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
+        10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
+        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
+        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
+    };
+    return _mm_loadu_si128((const __m128i*)k_shuffle + i);
+}
+#endif
+
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+
+    const block_q4_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+
+        const __m128i lowMask = _mm_set1_epi8(0xF);
+        const __m128i off = _mm_set1_epi8(8);
+
+        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+        __m128i bx = _mm_and_si128(lowMask, tmp);
+        __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
+        bx = _mm_sub_epi8(bx, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
+
+        bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
+        by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        bx = _mm_sub_epi8(bx, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
+
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__SSSE3__)
+    // set constants
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    const __m128i off = _mm_set1_epi8(8);
+
+    // Initialize accumulator with zeros
+    __m128 acc_0 = _mm_setzero_ps();
+    __m128 acc_1 = _mm_setzero_ps();
+    __m128 acc_2 = _mm_setzero_ps();
+    __m128 acc_3 = _mm_setzero_ps();
+
+    // First round without accumulation
+    {
+        _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 0 and 1
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 2 and 3
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        acc_0 = _mm_mul_ps( d_0_1, p0 );
+        acc_1 = _mm_mul_ps( d_0_1, p1 );
+        acc_2 = _mm_mul_ps( d_2_3, p2 );
+        acc_3 = _mm_mul_ps( d_2_3, p3 );
+    }
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    // Main loop
+    for (int i = 2; i < nb; i+=2) {
+        _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 0 and 1
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 2 and 3
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+        __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+        __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+        __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+        // Acummulate
+        acc_0 = _mm_add_ps(p0_d, acc_0);
+        acc_1 = _mm_add_ps(p1_d, acc_1);
+        acc_2 = _mm_add_ps(p2_d, acc_2);
+        acc_3 = _mm_add_ps(p3_d, acc_3);
+    }
+
+    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        // subtract offset
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
+    }
+
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        int sumi = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int v0 = (x[i].qs[j] & 0x0F) - 8;
+            const int v1 = (x[i].qs[j] >>   4) - 8;
+
+            sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
+        }
+
+        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
+    }
+
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+
+    const block_q4_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
+    // TODO: add WASM SIMD
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float summs = 0;
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q8_1 * restrict y0 = &y[i + 0];
+        const block_q8_1 * restrict y1 = &y[i + 1];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+#elif defined(__AVX2__) || defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float d0 = GGML_FP16_TO_FP32(x[i].d);
+        const float d1 = y[i].d;
+
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+
+        const __m256 d0v = _mm256_set1_ps( d0 );
+        const __m256 d1v = _mm256_set1_ps( d1 );
+
+        // Compute combined scales
+        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        const __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
+
+        const __m256 xy = mul_sum_us8_pairs_float(bx, by);
+
+        // Accumulate d0*d1*x*y
+#if defined(__AVX2__)
+        acc = _mm256_fmadd_ps( d0d1, xy, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+#endif
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    }
+
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        int sumi = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int v0 = (x[i].qs[j] & 0x0F);
+            const int v1 = (x[i].qs[j] >>   4);
+
+            sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
+        }
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    }
+
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_0);
+
+    const block_q5_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    uint32_t qh0;
+    uint32_t qh1;
+
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q5_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        // extract the 5th bit via lookup table ((!b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];
+
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    uint32_t qh;
+    uint64_t tmp[4];
+
+    // TODO: check if unrolling this is better
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q8_0 * restrict y0 = &y[i];
+
+        const v128_t m4b  = wasm_i8x16_splat(0x0F);
+
+        // extract the 5th bit
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_1[(qh >> 24)       ];
+
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
+
+        const v128_t v0 = wasm_v128_load(x0->qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
+        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
+                        wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
+    }
+
+    *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+         wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
+        bx = _mm256_or_si256(bx, bxhi);
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8((char)0xF0);
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_andnot_si128(bxhil, mask);
+        bxhih = _mm_andnot_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx = MM256_SET_M128I(bxh, bxl);
+
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    uint32_t qh;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    // These tempory registers are for masking and shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
+
+    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
+    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+    for (int i = 0; i < nb; i++) {
+        memcpy(&qh, x[i].qh, sizeof(uint32_t));
+
+        // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+
+        // ((qh & (1u << (j + 16))) >> (j + 12));
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
+        vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
+
+        // narrowing
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+        // load
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
+    }
+
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        int sumi = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
+
+            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
+            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;
+
+            sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
+        }
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
+    }
+
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_1);
+
+    const block_q5_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
+
+    uint32_t qh0;
+    uint32_t qh1;
+
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q5_1 * restrict x0 = &x[i];
+        const block_q5_1 * restrict x1 = &x[i + 1];
+        const block_q8_1 * restrict y0 = &y[i];
+        const block_q8_1 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
+        summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
+
+        // extract the 5th bit via lookup table ((b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];
+
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // add high bit
+        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    float summs = 0.0f;
+
+    uint32_t qh;
+    uint64_t tmp[4];
+
+    // TODO: check if unrolling this is better
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_1 * restrict x0 = &x[i];
+        const block_q8_1 * restrict y0 = &y[i];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
+
+        const v128_t m4b = wasm_i8x16_splat(0x0F);
+
+        // extract the 5th bit
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_0[(qh >> 24)       ];
+
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
+
+        const v128_t v0 = wasm_v128_load(x0->qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+        // add high bit
+        const v128_t v0lf = wasm_v128_or(v0l, qhl);
+        const v128_t v0hf = wasm_v128_or(v0h, qhh);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv,
+                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
+    }
+
+    *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+         wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
+        bx = _mm256_or_si256(bx, bxhi);
+
+        const __m256 dy = _mm256_set1_ps(y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
+
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8(0x10);
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_and_si128(bxhil, mask);
+        bxhih = _mm_and_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx = MM256_SET_M128I(bxh, bxl);
+
+        const __m256 dy = _mm256_set1_ps(y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
+
+        acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    uint32_t qh;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    // temporary registers for shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+    for (int i = 0; i < nb; i++) {
+        memcpy(&qh, x[i].qh, sizeof(uint32_t));
+
+        // load qh
+        vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
+
+        // ((qh >> (j +  0)) << 4) & 0x10;
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
+
+        // ((qh >> (j + 12))     ) & 0x10;
+        vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
+
+        // narrowing
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+        // load
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    }
+
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        int sumi = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
+            const int32_t x1 = (x[i].qs[j] >>  4) | xh_1;
+
+            sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
+        }
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    }
+
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+
+    const block_q8_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const int8x16_t x0_0 = vld1q_s8(x0->qs);
+        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
+        const int8x16_t x1_0 = vld1q_s8(x1->qs);
+        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
+
+        // load y
+        const int8x16_t y0_0 = vld1q_s8(y0->qs);
+        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
+        const int8x16_t y1_0 = vld1q_s8(y1->qs);
+        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+
+#else
+        const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
+        const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
+        const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
+        const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
+
+        const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
+        const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
+        const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
+        const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
+
+        const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
+        const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+        const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
+        const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__) || defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
+        __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        // Multiply q with scale and accumulate
+#if defined(__AVX2__)
+        acc = _mm256_fmadd_ps( d, q, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
+#endif
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+    size_t vl = __riscv_vsetvl_e8m1(qk);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
+        vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
+
+        vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
+
+        vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
+
+        sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
+    }
+
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        int sumi = 0;
+
+        for (int j = 0; j < qk; j++) {
+            sumi += x[i].qs[j]*y[i].qs[j];
+        }
+
+        sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
+    }
+
+    *s = sumf;
+#endif
+}
+
+#if QK_K == 256
+void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+
+    const block_q2_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m3 = vdupq_n_u8(0x3);
+    const uint8x16_t m4 = vdupq_n_u8(0xF);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+
+    int8x16x2_t q2bytes;
+    uint8_t aux[16];
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * restrict sc = x[i].scales;
+
+        const uint8x16_t mins_and_scales = vld1q_u8(sc);
+        const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
+        vst1q_u8(aux, scales);
+
+        const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
+        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
+        const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+        const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
+                                       vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
+        const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
+                                       vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
+        sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
+
+        int isum = 0;
+        int is = 0;
+
+// We use this macro instead of a function call because for some reason
+// the code runs 2-3% slower, even if the function is declared inline
+#if defined(__ARM_FEATURE_DOTPROD)
+#define MULTIPLY_ACCUM_WITH_SCALE(index)\
+        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
+#else
+#define MULTIPLY_ACCUM_WITH_SCALE(index)\
+        {\
+    const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
+                                   vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
+    const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
+                                   vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
+    isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
+        }
+#endif
+
+#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
+        q8bytes = vld1q_s8_x2(q8); q8 += 32;\
+        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
+        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
+        MULTIPLY_ACCUM_WITH_SCALE((index));
+
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
+
+            int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
+            q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
+            MULTIPLY_ACCUM_WITH_SCALE(0);
+
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
+
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
+
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
+
+            is += 8;
+        }
+        sum += d * isum;
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
+        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+        const __m256i mins = _mm256_cvtepi8_epi16(mins8);
+        const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
+
+        const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
+        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
+            const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
+            const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
+            const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
+
+            __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+            __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+            __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
+            __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
+
+            p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
+            p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
+            p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
+            p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
+
+            p0 = _mm256_add_epi32(p0, p1);
+            p2 = _mm256_add_epi32(p2, p3);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
+        }
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(0x3);
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(0x2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // load mins and scales from block_q2_K.scales[QK_K/16]
+        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
+        const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+        const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
+        const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
+
+        // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
+        const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
+        const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
+
+        // sumf += -dmin * summs in 32bits*8
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
+
+        const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
+        const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
+        const __m128i scales[2] = { scales_0, scales_1 };
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            // load 2bits*16*8 from block_q2_K.qs[QK_K/4]
+            __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
+            const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+            const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+            const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+            const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+            q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
+            const __m128i q2_1 = _mm_and_si128(q2bits, m3);
+            const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+            const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+            const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+            // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
+            __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
+            __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
+            __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
+            __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
+            __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
+            __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
+            __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
+            __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
+
+            // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
+            __m128i shuffle = _mm_set1_epi16(0x0100);
+            p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
+
+            p0 = _mm_add_epi32(p0, p1);
+            p2 = _mm_add_epi32(p2, p3);
+            p4 = _mm_add_epi32(p4, p5);
+            p6 = _mm_add_epi32(p6, p7);
+
+            // isum in 32bits*4*2
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
+        }
+
+        // sumf += dall * isum - dmin * summs in 32bits
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        size_t vl = 16;
+
+        vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+        vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+
+        vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
+
+        vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+        vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+        vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+        vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+
+        sumf  += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
+
+        vl = 32;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+
+        uint8_t is=0;
+        int isum=0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load Q2
+            vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+
+            vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+            vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl);
+            vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl);
+            vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl);
+
+            // duplicate scale elements for product
+            vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl);
+            vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl);
+            vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl);
+            vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
+
+            vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+            vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+            vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+            vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
+
+            // load Q8
+            vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+            vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl);
+            vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
+
+            vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+            vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+            vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+            vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
+
+            isum += __riscv_vmv_x_s_i32m1_i32(isum1);
+
+            q2+=32;  q8+=128;  is=8;
+
+        }
+
+        sumf += dall * isum;
+
+    }
+
+    *s = sumf;
+
+#else
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        int summs = 0;
+        for (int j = 0; j < 16; ++j) {
+            summs += y[i].bsums[j] * (sc[j] >> 4);
+        }
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        int isum = 0;
+        int is = 0;
+        int d;
+        for (int k = 0; k < QK_K/128; ++k) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                d = sc[is++] & 0xF;
+                int isuml = 0;
+                for (int l =  0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+                isum += d * isuml;
+                d = sc[is++] & 0xF;
+                isuml = 0;
+                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+                isum += d * isuml;
+                shift += 2;
+                q8 += 32;
+            }
+            q2 += 32;
+        }
+        sumf += dall * isum - dmin * summs;
+    }
+    *s = sumf;
+#endif
+}
+
+#else
+
+void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+
+    const block_q2_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m3 = vdupq_n_u8(0x3);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+
+    int8x16x4_t q2bytes;
+
+    uint32_t aux32[2];
+    const uint8_t * scales = (const uint8_t *)aux32;
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const float dmin = -y[i].d * (float)x[i].dmin;
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+
+        aux32[0] = sc[0] & 0x0f0f0f0f;
+        aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
+
+        int isum1 = 0, isum2 = 0;
+
+        const uint8x16_t q2bits = vld1q_u8(q2);
+
+        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+
+        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
+        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
+        q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
+        q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
+        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
+        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
+#else
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        isum1 += vaddvq_s16(p1) * scales[0];
+        isum2 += vaddvq_s16(p2) * scales[1];
+
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        isum1 += vaddvq_s16(p3) * scales[2];
+        isum2 += vaddvq_s16(p4) * scales[3];
+#endif
+        sum += d * (isum1 + isum2);
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t ud, um;
+    const uint8_t * restrict db = (const uint8_t *)&ud;
+    const uint8_t * restrict mb = (const uint8_t *)&um;
+
+    float summs = 0;
+
+    // TODO: optimize this
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+        ud = (sc[0] >> 0) & 0x0f0f0f0f;
+        um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+        summs += dmin * smin;
+
+        const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+        const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3);
+        const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+        const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+
+        const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0));
+        const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1));
+        const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0));
+        const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1));
+
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t ud, um;
+    const uint8_t * restrict db = (const uint8_t *)&ud;
+    const uint8_t * restrict mb = (const uint8_t *)&um;
+
+    float summs = 0;
+
+    // TODO: optimize this
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+        ud = (sc[0] >> 0) & 0x0f0f0f0f;
+        um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+        summs += dmin * smin;
+
+        const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+        const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+        const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+        const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+        const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
+        const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
+        const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
+        const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __riscv_v_intrinsic
+
+    uint32_t aux32[2];
+    const uint8_t * scales = (const uint8_t *)aux32;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const float dmin = -y[i].d * (float)x[i].dmin;
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+
+        aux32[0] = sc[0] & 0x0f0f0f0f;
+        aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
+
+        int isum1 = 0;
+        int isum2 = 0;
+
+        size_t vl = 16;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        // load Q2
+        vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl);
+
+        vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl));
+        vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl));
+        vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl));
+        vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl));
+
+        // load Q8, and take product with Q2
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl);
+        vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl);
+        vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl);
+        vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl);
+
+        isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0];
+        isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1];
+        isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2];
+        isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3];
+
+        sumf += d * (isum1 + isum2);
+
+    }
+
+    *s = sumf;
+
+#else
+
+    float sumf = 0;
+
+    int isum[4];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        int summs = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            summs += y[i].bsums[j] * (sc[j] >> 4);
+        }
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        isum[0] = isum[1] = isum[2] = isum[3] = 0;
+        for (int l =  0; l < 16; ++l) {
+            isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
+            isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
+            isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
+            isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
+        }
+        for (int l = 0; l < 4; ++l) {
+            isum[l] *= (sc[l] & 0xF);
+        }
+        sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
+    }
+    *s = sumf;
+#endif
+}
+#endif
+
+#if QK_K == 256
+void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    const block_q3_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    const uint8x16_t m3b = vdupq_n_u8(0x3);
+#ifdef __ARM_FEATURE_DOTPROD
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+
+    const uint8x16_t m0 = vdupq_n_u8(1);
+    const uint8x16_t m1 = vshlq_n_u8(m0, 1);
+    const uint8x16_t m2 = vshlq_n_u8(m0, 2);
+    const uint8x16_t m3 = vshlq_n_u8(m0, 3);
+    const int8_t m32 = 32;
+
+    int8x16x4_t q3bytes;
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict qh = x[i].hmask;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+
+        uint8x16x4_t q3h;
+
+        int32_t isum = 0;
+
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= m32;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
+            const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
+            const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
+
+            q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
+            q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
+            q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
+            q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
+
+            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
+#else
+            int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
+                                     vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
+            int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
+                                     vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
+            int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
+                                     vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
+            int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
+                                     vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
+            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
+#endif
+            scale += 4;
+
+            q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
+            q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
+            q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
+            q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
+
+            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
+#else
+            p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
+                           vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
+            p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
+                           vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
+            p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
+                           vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
+            p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
+                           vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
+            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
+#endif
+            scale += 4;
+
+            if (j == 0) {
+                qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
+                qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
+            }
+
+        }
+        sum += d * isum;
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m256i mone = _mm256_set1_epi8(1);
+    const __m128i m32 = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t aux[3];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        __m128i scales128 = _mm_set_epi32(
+                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = _mm_sub_epi8(scales128, m32);
+        const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
+        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
+
+        // high bit
+        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
+
+        // integer accumulator
+        __m256i sumi = _mm256_setzero_si256();
+
+        int bit = 0;
+        int is  = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load low 2 bits
+            const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
+
+            // prepare low and high bits
+            const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
+            const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
+            const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
+            const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
+            const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            // load Q8 quants
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
+            // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+            // and 2 if the high bit was set)
+            __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+            __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+            __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
+            __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+            __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
+            __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
+
+            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+            // multiply with scales
+            p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+            p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+            p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+            p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+            // accumulate
+            p16_0 = _mm256_add_epi32(p16_0, p16_1);
+            p16_2 = _mm256_add_epi32(p16_2, p16_3);
+            sumi  = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
+
+        }
+
+        // multiply with block scale and accumulate
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i mone = _mm_set1_epi8(1);
+    const __m128i m32 = _mm_set1_epi8(32);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    const uint32_t *aux;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Set up scales
+        aux = (const uint32_t *)x[i].scales;
+        __m128i scales128 = _mm_set_epi32(
+                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = _mm_sub_epi8(scales128, m32);
+        const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
+        const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
+        const __m128i scales[2] = { scales_0, scales_1 };
+
+        // high bit *128*2 from block_q3_K.hmask[QK_K/8]
+        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
+        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
+
+        // integer accumulator
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
+            const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
+            const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
+
+            // prepare low and high bits
+            const int bit = j << 2;
+
+            const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
+            const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
+            const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
+            const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
+
+            const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
+            const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
+            const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
+            const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
+
+            const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
+            const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
+            const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
+            const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
+
+            const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
+            const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
+            const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
+            const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
+
+            // load Q8 quants from block_q8_K.qs[QK_K]
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
+            // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+            // and 2 if the high bit was set)
+            __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
+            __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
+            __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
+            __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
+            __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
+            __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
+            __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
+            __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
+
+            __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
+            __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
+            __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
+            __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
+            __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
+            __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
+            __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
+
+            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
+            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
+
+            // multiply with scales
+            __m128i shuffle = _mm_set1_epi16(0x0100);
+            p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
+
+            // accumulate
+            p16_0 = _mm_add_epi32(p16_0, p16_1);
+            p16_2 = _mm_add_epi32(p16_2, p16_3);
+            p16_4 = _mm_add_epi32(p16_4, p16_5);
+            p16_6 = _mm_add_epi32(p16_6, p16_7);
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
+
+        }
+
+        // multiply with block scale and accumulate
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict qh = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= 32;
+
+
+        size_t vl = 32;
+        uint8_t m =  1;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
+
+        int sum_t = 0;
+
+        for (int j = 0; j < QK_K; j += 128) {
+
+            vl = 32;
+
+            // load Q3
+            vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
+
+            vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+            vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
+            vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
+            vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
+
+            // compute mask for subtraction
+            vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+            vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+            vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl);
+            m <<= 1;
+
+            // load Q8 and take product with Q3
+            vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+            vl = 16;
+
+            // retreive lane to multiply with scale
+            vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+            vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+            vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+            vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+            vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+            vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+            vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+            vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
+
+            sum_t +=  __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q3 += 32;    q8 += 128;   scale += 8;
+
+        }
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        sumf += d*sum_t;
+
+    }
+
+    *s = sumf;
+
+#else
+    // scalar version
+    // This function is written like this so the compiler can manage to vectorize most of it
+    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
+    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
+    // The ideal situation would be if we could just write the code once, and the compiler would
+    // automatically produce the best possible set of machine instructions, instead of us having to manually
+    // write vectorized versions for AVX, ARM_NEON, etc.
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    uint32_t auxs[4];
+    const int8_t * scales = (const int8_t*)auxs;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        uint8_t m = 1;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            q3 += 32;
+        }
+        a = aux8;
+
+        memcpy(auxs, x[i].scales, 12);
+        uint32_t tmp = auxs[2];
+        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+
+#endif
+
+}
+
+#else
+
+void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q3_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+#ifdef __ARM_FEATURE_DOTPROD
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+
+    const uint8x16_t m3b = vdupq_n_u8(0x3);
+    const uint8x16_t mh  = vdupq_n_u8(4);
+
+    int8x16x4_t q3bytes;
+
+    uint16_t aux16[2];
+    int8_t * scales = (int8_t *)aux16;
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        uint8x16x4_t q3h;
+
+        const uint8x8_t  hbits    = vld1_u8(x[i].hmask);
+        const uint8x16_t q3bits   = vld1q_u8(x[i].qs);
+        const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        for (int j = 0; j < 4; ++j) scales[j] -= 8;
+
+        int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
+
+        const float d = y[i].d * (float)x[i].d;
+
+        const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
+        q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
+        q3h.val[1] = vandq_u8(mh, htmp);
+        q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
+        q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4));
+
+        q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b),                q3h.val[0]));
+        q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
+        q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
+        q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6),                q3h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
+#else
+        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
+#endif
+
+        sum += d * isum;
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m256i m1 = _mm256_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint64_t aux64;
+
+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
+        const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
+
+        memcpy(&aux64, x[i].hmask, 8);
+
+        const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux);
+        __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
+        q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
+        q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
+
+        // load low 2 bits
+        const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+        // prepare low and high bits
+        const __m256i q3aux  = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits);
+        const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
+        const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
+
+        // load Q8 quants
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
+        // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+        // and 2 if the high bit was set)
+        const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+        const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+
+        __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+        __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+
+        p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+
+        // multiply with scales
+        p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+        p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+
+        p16_0 = _mm256_add_epi32(p16_0, p16_1);
+
+        // multiply with block scale and accumulate
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint64_t aux64;
+
+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
+        const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
+        const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
+        const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
+
+        memcpy(&aux64, x[i].hmask, 8);
+
+        __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
+        __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
+        __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
+        q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
+        q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
+        q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
+        q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
+
+        // load low 2 bits
+        const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+        // prepare low and high bits
+        const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
+        const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
+        const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
+        const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
+
+        // load Q8 quants
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
+        // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+        // and 2 if the high bit was set)
+        const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        // multiply with scales
+        p16_0 = _mm_madd_epi16(scale_0, p16_0);
+        p16_1 = _mm_madd_epi16(scale_1, p16_1);
+        p16_2 = _mm_madd_epi16(scale_2, p16_2);
+        p16_3 = _mm_madd_epi16(scale_3, p16_3);
+
+        p16_0 = _mm_add_epi32(p16_0, p16_2);
+        p16_1 = _mm_add_epi32(p16_1, p16_3);
+        __m256i p16 = MM256_SET_M128I(p16_1, p16_0);
+
+        // multiply with block scale and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    uint16_t aux16[2];
+    int8_t * scales = (int8_t *)aux16;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        for (int j = 0; j < 4; ++j) scales[j] -= 8;
+
+        int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
+
+        const float d = y[i].d * (float)x[i].d;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load qh
+        vuint8mf4_t qh_x1   = __riscv_vle8_v_u8mf4(x[i].hmask, 8);
+        vuint8mf2_t qh_x2   = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
+
+        size_t vl = 16;
+
+        // extend and combine both qh_x1 and qh_x2
+        vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
+
+        vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
+        vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl);
+        vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
+        vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl);
+
+        // load Q3
+        vuint8mf2_t q3_x  = __riscv_vle8_v_u8mf2(q3, vl);
+
+        vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl);
+        vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl);
+        vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl);
+        vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl);
+
+        vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0);
+        vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1);
+        vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2);
+        vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3);
+
+        // load Q8 and take product with Q3
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3];
+
+        sumf += d * isum;
+
+    }
+
+    *s = sumf;
+
+#else
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    int32_t scales[4];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+        int8_t * restrict a = aux8;
+        for (int l = 0; l < 8; ++l) {
+            a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4);
+            a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4);
+            a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4);
+            a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4);
+            a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4);
+            a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4);
+            a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4);
+            a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
+        }
+
+        scales[0] = (x[i].scales[0] & 0xF) - 8;
+        scales[1] = (x[i].scales[0] >>  4) - 8;
+        scales[2] = (x[i].scales[1] & 0xF) - 8;
+        scales[3] = (x[i].scales[1] >>  4) - 8;
+
+        memset(aux32, 0, 8*sizeof(int32_t));
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+
+#endif
+
+}
+#endif
+
+#if QK_K == 256
+void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q4_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+#ifdef __ARM_FEATURE_DOTPROD
+    const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+    int8x16x2_t q4bytes;
+    int8x16x2_t q8bytes;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, 12);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        int32_t sumi1 = 0;
+        int32_t sumi2 = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
+
+#ifdef __ARM_FEATURE_DOTPROD
+            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+
+            const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+            sumi1 += vaddvq_s32(p1) * scales[2*j+0];
+
+            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+
+            const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+
+            sumi2 += vaddvq_s32(p2) * scales[2*j+1];
+#else
+            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+            sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
+
+            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+            sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
+
+#endif
+        }
+
+        sumf += d * (sumi1 + sumi2);
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+    __m128 acc_m = _mm_setzero_ps();
+
+   for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+        acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
+
+        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4l = _mm256_and_si256(q4bits, m4);
+            const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+            const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+            p16l = _mm256_madd_epi16(scale_l, p16l);
+
+            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+            p16h = _mm256_madd_epi16(scale_h, p16h);
+            const __m256i sumj = _mm256_add_epi32(p16l, p16h);
+
+            sumi = _mm256_add_epi32(sumi, sumj);
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(0x2);
+
+    __m256 acc = _mm256_setzero_ps();
+    __m128 acc_m = _mm_setzero_ps();
+
+   for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i scales = _mm_cvtepu8_epi16(utmps);
+        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
+
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+        const __m128i prod = _mm_madd_epi16(mins, q8s);
+        acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        __m128i shuffle = _mm_set1_epi16(0x0100);
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+
+            __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
+            const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
+            q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
+            const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
+
+            const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
+            p16l = _mm_madd_epi16(scale_l, p16l);
+            sumi_0 = _mm_add_epi32(sumi_0, p16l);
+            const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
+            p16l = _mm_madd_epi16(scale_l, p16l);
+            sumi_1 = _mm_add_epi32(sumi_1, p16l);
+
+            const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
+            p16h = _mm_madd_epi16(scale_h, p16h);
+            sumi_0 = _mm_add_epi32(sumi_0, p16h);
+            const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
+            p16h = _mm_madd_epi16(scale_h, p16h);
+            sumi_1 = _mm_add_epi32(sumi_1, p16h);
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
+
+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        size_t vl = 8;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+        vint16mf2_t q8sums   = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8  = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t  prod   = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        vl = 32;
+
+        int32_t sum_1 = 0;
+        int32_t sum_2 = 0;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q4
+            vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+            // load Q8 and multiply it with lower Q4 nibble
+            vint8m1_t  q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t  q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+            vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+            vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
+
+            sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
+
+            // load Q8 and multiply it with upper Q4 nibble
+            vint8m1_t  q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+            vint8m1_t  q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+            vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+            vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
+
+            sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
+
+            q4 += 32;    q8 += 64;
+
+        }
+
+        sumf += d*(sum_1 + sum_2);
+
+    }
+
+    *s = sumf;
+
+#else
+
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        for (int j = 0; j < QK_K/64; ++j) {
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+            a += 32;
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
+            a += 32; q4 += 32;
+        }
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        int sumi = 0;
+        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            int32_t scale = scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf -= dmin * sumi;
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+#else
+void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q4_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+#ifdef __ARM_FEATURE_DOTPROD
+    const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+    float sumf = 0;
+
+    int8x16x2_t q4bytes;
+    int8x16x4_t q8bytes;
+
+    float sum_mins = 0.f;
+
+    uint16_t aux16[2];
+    const uint8_t * restrict scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t * restrict a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
+        sum_mins += y[i].d * (float)x[i].d[1] * summi;
+
+        const float d = y[i].d * (float)x[i].d[0];
+
+        const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
+
+#ifdef __ARM_FEATURE_DOTPROD
+        q8bytes = vld1q_s8_x4(q8);
+        q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+
+        const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+        const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
+
+        q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+
+        const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
+        const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
+
+#else
+        q8bytes = vld1q_s8_x4(q8);
+        q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
+
+        q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
+        int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
+
+#endif
+        sumf += d * (sumi1 + sumi2);
+
+    }
+
+    *s = sumf - sum_mins;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d;
+        const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+        const __m256i q4l = _mm256_and_si256(q4bits, m4);
+        const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+        const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+        const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+
+        const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
+
+        const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);
+
+    }
+
+    *s = hsum_float_8(acc) - summs;
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d;
+        const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
+        const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
+        const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
+        const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
+        const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
+        const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
+        const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc);
+
+        const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
+        const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc);
+
+    }
+
+    *s = hsum_float_8(acc) - summs;
+
+#elif defined __riscv_v_intrinsic
+
+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
+
+        size_t vl = 32;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        // load Q4
+        vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+        // load Q8 and multiply it with lower Q4 nibble
+        vint8m1_t  q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+        vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl);
+        vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl);
+
+        sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1);
+
+        // load Q8 and multiply it with upper Q4 nibble
+        vint8m1_t  q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+        vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+        vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl);
+
+        sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2);
+
+    }
+
+    *s = sumf;
+
+#else
+
+    uint8_t aux8[QK_K];
+    int16_t aux16[16];
+    float   sums [8];
+    memset(sums, 0, 8*sizeof(float));
+
+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+        uint8_t * restrict a = aux8;
+        for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
+        for (int l = 0; l < 32; ++l) a[l+32] = q4[l]  >> 4;
+
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
+
+        for (int j = 0; j < QK_K/32; ++j) {
+            for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            q8 += 16; a += 16;
+            for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
+            q8 += 16; a += 16;
+            const float dl = d * scales[j];
+            for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
+        }
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+#endif
+
+#if QK_K == 256
+void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q5_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
+
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    const uint8x16_t mone = vdupq_n_u8(1);
+    const uint8x16_t mtwo = vdupq_n_u8(2);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+    int8x16x4_t q5bytes;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        int32_t sumi_mins = vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+
+        uint8x16x4_t q5h;
+
+        int32_t sumi = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
+            const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+
+            q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+            q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+            q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
+            q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
+            qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
+            qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
+
+            q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
+            q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
+            q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
+            q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+            sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+            sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
+#else
+
+            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+            sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
+
+            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+            sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
+#endif
+        }
+
+        sumf += d * sumi - dmin * sumi_mins;
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m128i mzero = _mm_setzero_si128();
+    const __m256i mone  = _mm256_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.f;
+
+   for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+#if QK_K == 256
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+#else
+        // TODO
+        const float d = 0, dmin = 0;
+#endif
+
+        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+        summs += dmin * _mm_extract_epi32(hsum, 0);
+
+        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
+
+        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
+        __m256i hmask = mone;
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        int bit = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
+
+            const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+            const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+            const __m256i q5_0  = _mm256_add_epi8(q5l_0, q5h_0);
+            hmask = _mm256_slli_epi16(hmask, 1);
+
+            const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+            const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+            const __m256i q5_1  = _mm256_add_epi8(q5l_1, q5h_1);
+            hmask = _mm256_slli_epi16(hmask, 1);
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
+
+            p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+            p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mzero = _mm_setzero_si128();
+    const __m128i mone  = _mm_set1_epi8(1);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.f;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i scales = _mm_cvtepu8_epi16(utmps);
+        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
+
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+        const __m128i prod = _mm_madd_epi16(mins, q8s);
+        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+        summs += dmin * _mm_extract_epi32(hsum, 0);
+
+        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
+        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
+        __m128i hmask = mone;
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        int bit = 0;
+
+        __m128i shuffle = _mm_set1_epi16(0x0100);
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+
+            const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+            const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+
+            __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
+            __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
+            __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+            __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+            __m128i q5_0  = _mm_add_epi8(q5l_0, q5h_0);
+            __m128i q5_1  = _mm_add_epi8(q5l_1, q5h_1);
+            hmask = _mm_slli_epi16(hmask, 1);
+
+            __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
+            p16_0 = _mm_madd_epi16(scale_0, p16_0);
+            p16_1 = _mm_madd_epi16(scale_0, p16_1);
+
+            q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
+            q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
+            q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+            q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+            q5_0  = _mm_add_epi8(q5l_0, q5h_0);
+            q5_1  = _mm_add_epi8(q5l_1, q5h_1);
+            hmask = _mm_slli_epi16(hmask, 1);
+
+            q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
+            __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
+            p16_2 = _mm_madd_epi16(scale_1, p16_2);
+            p16_3 = _mm_madd_epi16(scale_1, p16_3);
+
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+    float sums = 0.0;
+
+    size_t vl;
+
+    for (int i = 0; i < nb; ++i) {
+
+        vl = 8;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+        vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        vl = 32;
+        int32_t aux32 = 0;
+        int is = 0;
+
+        uint8_t m = 1;
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q5 and Q8
+            vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
+            vint8m1_t  q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t  q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
+
+            // compute mask for addition
+            vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl);
+            m <<= 1;
+
+            vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl);
+            m <<= 1;
+
+            vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
+            vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
+
+            vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
+            vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
+
+            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
+            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
+
+            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
+            q5 += 32;    q8 += 64;
+
+        }
+
+        vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
+        sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+
+    }
+
+    *s = sumf+sums;
+
+#else
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        uint8_t m = 1;
+        for (int j = 0; j < QK_K/64; ++j) {
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
+            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+            a += 32; m <<= 1;
+            q4 += 32;
+        }
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        int sumi = 0;
+        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            int32_t scale = scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf -= dmin * sumi;
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+#else
+
+void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q5_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    const uint8x16_t mh = vdupq_n_u8(16);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+    int8x16x4_t q5bytes;
+    uint8x16x4_t q5h;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const int8_t * sc = x[i].scales;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint8x8_t qhbits = vld1_u8(qh);
+
+        const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
+        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+
+        const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
+        q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
+        q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2));
+        q5h.val[2] = vbicq_u8(mh, htmp);
+        q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2));
+
+        q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0]));
+        q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1]));
+        q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
+        q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+        int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+        int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+        int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+        int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+
+        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+
+#else
+
+        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
+
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
+
+        sumf += d*sumi;
+#endif
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i mone  = _mm256_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+        const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
+        const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128);
+
+        const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
+        const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
+
+        const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+        const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
+        const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
+        const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
+        const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
+
+        const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));
+
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mone  = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+        const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
+        const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
+        const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
+        const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
+
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
+
+        const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
+        const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
+        const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
+        const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
+
+        const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
+        const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
+        const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
+        const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
+        const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
+
+        const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
+        const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const int8_t * sc = x[i].scales;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load qh
+        vuint8mf4_t qh_x1   = __riscv_vle8_v_u8mf4(qh, 8);
+        vuint8mf2_t qh_x2   = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
+
+        size_t vl = 16;
+
+        // combine both qh_1 and qh_2
+        vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
+
+        vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
+        vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl);
+        vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl);
+        vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
+
+        vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0);
+        vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1);
+        vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2);
+        vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3);
+
+        // load q5
+        vuint8mf2_t q5_x1  = __riscv_vle8_v_u8mf2(q5, vl);
+        vuint8mf2_t q5_x2  = __riscv_vle8_v_u8mf2(q5+16, vl);
+
+        vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl));
+        vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl));
+        vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl));
+        vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl));
+
+        vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl);
+        vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl);
+        vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl);
+        vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl);
+
+        // load Q8 and multiply it with Q5
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0);
+        int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1);
+        int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2);
+        int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3);
+
+        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+
+    }
+
+    *s = sumf;
+
+#else
+
+    int8_t aux8[QK_K];
+    int16_t aux16[16];
+    float   sums [8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        int8_t * restrict a = aux8;
+        for (int l = 0; l < 32; ++l) {
+            a[l+ 0] = q4[l] & 0xF;
+            a[l+32] = q4[l]  >> 4;
+        }
+        for (int is = 0; is < 8; ++is) {
+            uint8_t m = 1 << is;
+            for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const int8_t * restrict sc = x[i].scales;
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float dl = d * sc[j];
+            for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l <  8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
+            q8 += 16; a += 16;
+        }
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+#endif
+
+
+#if QK_K == 256
+void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q6_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    float sum = 0;
+
+    const uint8x16_t m4b = vdupq_n_u8(0xF);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+    //const int8x16_t  m32s = vdupq_n_s8(32);
+
+    const uint8x16_t mone = vdupq_n_u8(3);
+
+    int8x16x4_t q6bytes;
+    uint8x16x4_t q6h;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
+        const int8x16_t scales = vld1q_s8(scale);
+        const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+
+        const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
+                                                   vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
+                                         vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
+                                                   vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
+        int32_t isum_mins = vaddvq_s32(prod);
+
+        int32_t isum = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
+            uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
+            int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+
+            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+            uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
+            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 2);
+            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
+            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
+            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
+            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
+            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
+            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
+            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
+            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+            isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+            scale += 4;
+
+#else
+
+            int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                     vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+            int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                     vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
+            scale += 2;
+
+            int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                     vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+            int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                     vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+            isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
+            scale += 2;
+#endif
+
+            q8bytes = vld1q_s8_x4(q8); q8 += 64;
+
+            shifted = vshrq_n_u8(qhbits.val[0], 4);
+            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 4);
+            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[0], 6);
+            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 6);
+            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
+            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
+            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
+            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
+            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
+            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
+            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
+            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+            isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+            scale += 4;
+
+            //for (int l = 0; l < 4; ++l) {
+            //    const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
+            //    isum += vaddvq_s32(p) * *scale++;
+            //}
+#else
+            p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                    vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+            p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                    vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
+            scale += 2;
+
+            p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                    vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+            p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                    vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+            isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
+            scale += 2;
+#endif
+
+        }
+        //sum += isum * d_all * y[i].d;
+        sum += d_all * y[i].d * (isum - 32 * isum_mins);
+
+    }
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i m2 = _mm256_set1_epi8(3);
+    const __m256i m32s = _mm256_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+            is += 4;
+
+            const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
+
+            const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
+            const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
+            const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
+            const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
+
+            const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
+            const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
+            const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
+            const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
+            __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+            __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
+            __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+            __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
+            __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
+
+            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+            p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
+            p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+            p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
+            p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
+
+        }
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m32s = _mm_set1_epi8(32);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+            const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+
+            const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
+            const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
+            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
+            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
+            const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
+            const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
+            const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
+            const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
+
+            const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+
+            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
+            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
+            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
+            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
+            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
+            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
+            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
+            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
+
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
+            __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
+            __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
+            __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
+            __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
+            __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
+            __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
+            __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
+
+            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
+            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
+            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
+            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
+            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
+            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
+            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
+
+            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
+            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+
+            p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
+            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
+            p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
+            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
+
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
+
+        }
+
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        size_t vl;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        int sum_t = 0;
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            vl = 32;
+
+            // load qh
+            vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+
+            // load Q6
+            vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+            vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+
+            vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+            vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+            vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+            vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+
+            vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+            vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+            vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+            vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+
+            vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+            vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+            vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+            vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+
+            vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+            vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+            vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+            vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+
+            // load Q8 and take product
+            vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+            vl = 16;
+
+            vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+            vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+            vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+            vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+            vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+            vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+            vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+            vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+
+            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q6 += 64;   qh += 32;   q8 += 128;   is=8;
+
+        }
+
+        sumf += d * sum_t;
+
+    }
+
+    *s = sumf;
+
+#else
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            }
+            a  += 128;
+            q4 += 64;
+            qh += 32;
+        }
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int scale = x[i].scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+#else
+
+void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q6_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    float sum = 0;
+
+    const uint8x16_t m4b = vdupq_n_u8(0xF);
+    const int8x16_t  m32s = vdupq_n_s8(32);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+
+    const uint8x16_t mone = vdupq_n_u8(3);
+
+    int8x16x4_t q6bytes;
+    uint8x16x4_t q6h;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = (float)x[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        int32_t isum = 0;
+
+        uint8x16_t   qhbits = vld1q_u8(qh);
+        uint8x16x2_t q6bits = vld1q_u8_x2(q6);
+        int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+
+        q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
+        uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
+        q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+        shifted = vshrq_n_u8(qhbits, 4);
+        q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+        shifted = vshrq_n_u8(qhbits, 6);
+        q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+        q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
+        q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
+        q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
+        q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+        isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+#else
+
+        int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
+
+        int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
+#endif
+
+        sum += isum * d_all * y[i].d;
+
+    }
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i m2 = _mm256_set1_epi8(3);
+    const __m256i m32s = _mm256_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+        const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
+        const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
+
+        const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
+        const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
+        __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+
+        __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
+        __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+
+        p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+
+        p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+
+        sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(3);
+    const __m128i m32s = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+        const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
+        const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
+        const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
+        const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+
+        const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
+        const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
+        const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
+        const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
+        __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
+        __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
+        __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+        p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+        p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+
+        sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+        sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = (float)x[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        int32_t isum = 0;
+
+        size_t vl = 16;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load Q6
+        vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl);
+        vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl);
+
+        // load qh
+        vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl);
+
+        vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+
+        vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl);
+        vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl);
+        vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl);
+        vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl);
+
+        vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl);
+        vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl);
+        vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl);
+        vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl);
+
+        // load Q8 and take product
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3];
+
+        sumf += isum * d_all * y[i].d;
+
+    }
+
+    *s = sumf;
+
+#else
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        for (int l = 0; l < 16; ++l) {
+            a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            a[l+32] = (int8_t)((q4[l+ 0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            a[l+48] = (int8_t)((q4[l+16] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+        }
+        int is = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int scale = x[i].scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+#endif
diff --git a/src/ggml-quants.h b/src/ggml-quants.h
new file mode 100644 (file)
index 0000000..70c12c2
--- /dev/null
@@ -0,0 +1,224 @@
+#pragma once
+
+#include "ggml-impl.h"
+
+// GGML internal header
+
+#include <stdint.h>
+#include <stddef.h>
+
+#define QK4_0 32
+typedef struct {
+    ggml_fp16_t d;          // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+    ggml_fp16_t d;          // delta
+    ggml_fp16_t m;          // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    ggml_fp16_t m;         // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    int8_t  qs[QK8_0];     // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
+
+#define QK8_1 32
+typedef struct {
+    float d;               // delta
+    float s;               // d * sum(qs[i])
+    int8_t  qs[QK8_1];     // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
+
+//
+// Super-block quantization structures
+//
+
+// Super-block size
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
+#define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
+
+// 2-bit quantization
+// weight is represented as x = a * q + b
+// 16 blocks of 16 elements each
+// Effectively 2.5625 bits per weight
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    ggml_fp16_t d;           // super-block scale for quantized scales
+    ggml_fp16_t dmin;        // super-block scale for quantized mins
+} block_q2_K;
+static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
+
+// 3-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 3.4375 bits per weight
+#ifdef GGML_QKK_64
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+    uint8_t scales[2];
+    ggml_fp16_t d;             // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
+#else
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+    uint8_t scales[12];        // scales, quantized with 6 bits
+    ggml_fp16_t d;             // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
+#endif
+
+// 4-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 4.5 bits per weight
+#ifdef GGML_QKK_64
+typedef struct {
+    ggml_fp16_t d[2];          // super-block scales/mins
+    uint8_t scales[2];         // 4-bit block scales/mins
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
+typedef struct {
+    ggml_fp16_t d;             // super-block scale for quantized scales
+    ggml_fp16_t dmin;          // super-block scale for quantized mins
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+#endif
+
+// 5-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 5.5 bits per weight
+#ifdef GGML_QKK_64
+typedef struct {
+    ggml_fp16_t d;               // super-block scale
+    int8_t  scales[QK_K/16];     // 8-bit block scales
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
+typedef struct {
+    ggml_fp16_t d;               // super-block scale for quantized scales
+    ggml_fp16_t dmin;            // super-block scale for quantized mins
+    uint8_t scales[K_SCALE_SIZE];   // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
+
+// 6-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 6.5625 bits per weight
+typedef struct {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    ggml_fp16_t d;           // super-block scale
+} block_q6_K;
+static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
+
+// This is only used for intermediate quantization and dot products
+typedef struct {
+    float   d;              // delta
+    int8_t  qs[QK_K];       // quants
+    int16_t bsums[QK_K/16]; // sum of quants in groups of 16
+} block_q8_K;
+static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
+
+
+// Quantization
+void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
+void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
+void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k);
+void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k);
+void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k);
+void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k);
+
+void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);
+void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
+void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
+void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
+void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
+void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
+
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int k);
+void quantize_row_q8_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q8_1(const float * restrict x, void * restrict y, int k);
+
+void quantize_row_q2_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
+
+// Dequantization
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k);
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k);
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k);
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k);
+//void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k);
+
+void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
+void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
+void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
+void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
+void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
+
+// Dot product
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
index f01363cefb3b62797215486426708053bd603533..af9f96c3332510abf6378e4f95632b7b02d829f7 100644 (file)
@@ -1,11 +1,8 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+#define _USE_MATH_DEFINES // For M_PI on MSVC
 
-#include "ggml.h"
 #include "ggml-impl.h"
-
-#ifdef GGML_USE_K_QUANTS
-#include "k_quants.h"
-#endif
+#include "ggml-quants.h"
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #include <unistd.h>
 #endif
 
-// static_assert should be a #define, but if it's not,
-// fall back to the _Static_assert C11 keyword.
-// if C99 - static_assert is noop
-// ref: https://stackoverflow.com/a/53923785/4039976
-#ifndef static_assert
-#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
-#define static_assert(cond, msg) _Static_assert(cond, msg)
-#else
-#define static_assert(cond, msg) struct global_scope_noop_trick
-#endif
-#endif
-
 #if defined(_MSC_VER)
 // disable "possible loss of data" to avoid hundreds of casts
 // we should just be careful :)
@@ -110,24 +95,17 @@ typedef void * thread_ret_t;
 #include <unistd.h>
 
 #endif
+
 #ifdef GGML_USE_CPU_HBM
 #include <hbwmalloc.h>
 #endif
 
-// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
-#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
-#ifndef __FMA__
-#define __FMA__
-#endif
-#ifndef __F16C__
-#define __F16C__
-#endif
-#ifndef __SSE3__
-#define __SSE3__
-#endif
+#if defined(__APPLE__)
+#include <TargetConditionals.h>
 #endif
 
-#if defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
+#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \
+    (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH))
 
 #include <sys/wait.h>
 
@@ -290,228 +268,27 @@ inline static void * ggml_aligned_malloc(size_t size) {
 #include "ggml-opencl.h"
 #endif
 
-#undef MIN
-#undef MAX
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
 // floating point type used to accumulate sums
 typedef double ggml_float;
 
-// 16-bit float
-// on Arm, we use __fp16
-// on x86, we use uint16_t
-#if defined(__ARM_NEON) && !defined(_MSC_VER)
-
-// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
-//
-//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
-//
-#include <arm_neon.h>
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
-#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
-
-#define GGML_FP16_TO_FP32(x) ((float) (x))
-#define GGML_FP32_TO_FP16(x) (x)
-
-#else
-
-#ifdef __wasm_simd128__
-#include <wasm_simd128.h>
-#else
-#ifdef __POWER9_VECTOR__
-#include <altivec.h>
-#undef bool
-#define bool _Bool
-#else
-#if defined(_MSC_VER) || defined(__MINGW32__)
-#include <intrin.h>
-#else
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
-#if !defined(__riscv)
-#include <immintrin.h>
-#endif
-#endif
-#endif
-#endif
-#endif
-
-#ifdef __riscv_v_intrinsic
-#include <riscv_vector.h>
-#endif
-
-#ifdef __F16C__
-
-#ifdef _MSC_VER
-#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
-#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
-#else
-#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
-#endif
-
-#elif defined(__POWER9_VECTOR__)
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-/* the inline asm below is about 12% faster than the lookup method */
-#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
-#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
-
-static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-    register float f;
-    register double d;
-    __asm__(
-        "mtfprd %0,%2\n"
-        "xscvhpdp %0,%0\n"
-        "frsp %1,%0\n" :
-        /* temp */ "=d"(d),
-        /* out */  "=f"(f):
-        /* in */   "r"(h));
-    return f;
-}
-
-static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-    register double d;
-    register ggml_fp16_t r;
-    __asm__( /* xscvdphp can work on double or single precision */
-        "xscvdphp %0,%2\n"
-        "mffprd %1,%0\n" :
-        /* temp */ "=d"(d),
-        /* out */  "=r"(r):
-        /* in */   "f"(f));
-    return r;
-}
-
-#else
-
-// FP16 <-> FP32
-// ref: https://github.com/Maratyszcza/FP16
-
-static inline float fp32_from_bits(uint32_t w) {
-    union {
-        uint32_t as_bits;
-        float as_value;
-    } fp32;
-    fp32.as_bits = w;
-    return fp32.as_value;
-}
-
-static inline uint32_t fp32_to_bits(float f) {
-    union {
-        float as_value;
-        uint32_t as_bits;
-    } fp32;
-    fp32.as_value = f;
-    return fp32.as_bits;
-}
-
-static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-    const uint32_t w = (uint32_t) h << 16;
-    const uint32_t sign = w & UINT32_C(0x80000000);
-    const uint32_t two_w = w + w;
-
-    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
-#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
-    const float exp_scale = 0x1.0p-112f;
-#else
-    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
-#endif
-    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
-
-    const uint32_t magic_mask = UINT32_C(126) << 23;
-    const float magic_bias = 0.5f;
-    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
-
-    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
-    const uint32_t result = sign |
-        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
-    return fp32_from_bits(result);
-}
-
-static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
-    const float scale_to_inf = 0x1.0p+112f;
-    const float scale_to_zero = 0x1.0p-110f;
-#else
-    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
-    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
-#endif
-    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
-
-    const uint32_t w = fp32_to_bits(f);
-    const uint32_t shl1_w = w + w;
-    const uint32_t sign = w & UINT32_C(0x80000000);
-    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
-    if (bias < UINT32_C(0x71000000)) {
-        bias = UINT32_C(0x71000000);
-    }
-
-    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
-    const uint32_t bits = fp32_to_bits(base);
-    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
-    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
-    const uint32_t nonsign = exp_bits + mantissa_bits;
-    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
-}
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-
-#endif // __F16C__
-
-#endif // __ARM_NEON
-
 //
 // global data
 //
 
 // precomputed gelu table for f16 (128 KB)
-static ggml_fp16_t table_gelu_f16[1 << 16];
+static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
 
 // precomputed quick gelu table for f16 (128 KB)
-static ggml_fp16_t table_gelu_quick_f16[1 << 16];
+static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
 
 // precomputed silu table for f16 (128 KB)
-static ggml_fp16_t table_silu_f16[1 << 16];
+static ggml_fp16_t ggml_table_silu_f16[1 << 16];
 
 // precomputed exp table for f16 (128 KB)
-static ggml_fp16_t table_exp_f16[1 << 16];
-
-// precomputed f32 table for f16 (256 KB)
-static float table_f32_f16[1 << 16];
-
-#if defined(__ARM_NEON) || defined(__wasm_simd128__)
-#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
-#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
-#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
-#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
-#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
-#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
-#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
-#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)
-
-// precomputed tables for expanding 8bits to 8 bytes:
-static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
-static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
-#endif
+static ggml_fp16_t ggml_table_exp_f16[1 << 16];
 
-// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
-// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
-// This is also true for POWER9.
-#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
-
-inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
-    uint16_t s;
-    memcpy(&s, &f, sizeof(uint16_t));
-    return table_f32_f16[s];
-}
-
-#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
-#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
-
-#endif
+// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
+float ggml_table_f32_f16[1 << 16];
 
 // note: do not use these inside ggml.c
 // these are meant to be used via the ggml.h API
@@ -626,3071 +403,816 @@ int64_t ggml_cycles_per_ms(void) {
 
 static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 
-//
-// quantization
-//
+static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
+static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
 
-#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
-// multiply int8_t, add results pairwise twice
-static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
-    // Get absolute values of x vectors
-    const __m128i ax = _mm_sign_epi8(x, x);
-    // Sign the values of the y vectors
-    const __m128i sy = _mm_sign_epi8(y, x);
-    // Perform multiplication and create 16-bit values
-    const __m128i dot = _mm_maddubs_epi16(ax, sy);
-    const __m128i ones = _mm_set1_epi16(1);
-    return _mm_madd_epi16(ones, dot);
-}
-
-#if __AVX__ || __AVX2__ || __AVX512F__
-// horizontally add 8 floats
-static inline float hsum_float_8(const __m256 x) {
-    __m128 res = _mm256_extractf128_ps(x, 1);
-    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
-    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
-    res = _mm_add_ss(res, _mm_movehdup_ps(res));
-    return _mm_cvtss_f32(res);
-}
-
-// horizontally add 8 int32_t
-static inline int hsum_i32_8(const __m256i a) {
-    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
-    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
-    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
-    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
-    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
-}
-
-// horizontally add 4 int32_t
-static inline int hsum_i32_4(const __m128i a) {
-    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
-    const __m128i sum64 = _mm_add_epi32(hi64, a);
-    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
-    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
-}
-
-#if defined(__AVX2__) || defined(__AVX512F__)
-// spread 32 bits to 32 bytes { 0x00, 0xFF }
-static inline __m256i bytes_from_bits_32(const uint8_t * x) {
-    uint32_t x32;
-    memcpy(&x32, x, sizeof(uint32_t));
-    const __m256i shuf_mask = _mm256_set_epi64x(
-            0x0303030303030303, 0x0202020202020202,
-            0x0101010101010101, 0x0000000000000000);
-    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
-    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
-    bytes = _mm256_or_si256(bytes, bit_mask);
-    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
-}
-
-// Unpack 32 4-bit fields into 32 bytes
-// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
-{
-    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
-    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
-    const __m256i lowMask = _mm256_set1_epi8( 0xF );
-    return _mm256_and_si256(lowMask, bytes);
-}
-
-// add int16_t pairwise and return as float vector
-static inline __m256 sum_i16_pairs_float(const __m256i x) {
-    const __m256i ones = _mm256_set1_epi16(1);
-    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
-    return _mm256_cvtepi32_ps(summed_pairs);
-}
-
-static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
-#if __AVXVNNI__
-    const __m256i zero = _mm256_setzero_si256();
-    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
-    return _mm256_cvtepi32_ps(summed_pairs);
-#else
-    // Perform multiplication and create 16-bit values
-    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-    return sum_i16_pairs_float(dot);
-#endif
-}
+static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_I8] = {
+        .type_name                = "i8",
+        .blck_size                = 1,
+        .type_size                = sizeof(int8_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_I16] = {
+        .type_name                = "i16",
+        .blck_size                = 1,
+        .type_size                = sizeof(int16_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_I32] = {
+        .type_name                = "i32",
+        .blck_size                = 1,
+        .type_size                = sizeof(int32_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_F32] = {
+        .type_name                = "f32",
+        .blck_size                = 1,
+        .type_size                = sizeof(float),
+        .is_quantized             = false,
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type             = GGML_TYPE_F32,
+    },
+    [GGML_TYPE_F16] = {
+        .type_name                = "f16",
+        .blck_size                = 1,
+        .type_size                = sizeof(ggml_fp16_t),
+        .is_quantized             = false,
+        .to_float                 = (ggml_to_float_t) ggml_fp16_to_fp32_row,
+        .from_float               = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+        .from_float_reference     = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f16,
+        .vec_dot_type             = GGML_TYPE_F16,
+    },
+    [GGML_TYPE_Q4_0] = {
+        .type_name                = "q4_0",
+        .blck_size                = QK4_0,
+        .type_size                = sizeof(block_q4_0),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_0,
+        .from_float               = quantize_row_q4_0,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_0_reference,
+        .vec_dot                  = ggml_vec_dot_q4_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .type_name                = "q4_1",
+        .blck_size                = QK4_1,
+        .type_size                = sizeof(block_q4_1),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_1,
+        .from_float               = quantize_row_q4_1,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_1_reference,
+        .vec_dot                  = ggml_vec_dot_q4_1_q8_1,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+    },
+    [4] = { // GGML_TYPE_Q4_2
+        .type_name                = "DEPRECATED",
+        .blck_size                = 0,
+        .type_size                = 0,
+        .is_quantized             = false,
+        .to_float                 = NULL,
+        .from_float               = NULL,
+        .from_float_reference     = NULL,
+        .vec_dot                  = NULL,
+        .vec_dot_type             = GGML_TYPE_COUNT,
+    },
+    [5] = { // GGML_TYPE_Q4_3
+        .type_name                = "DEPRECATED",
+        .blck_size                = 0,
+        .type_size                = 0,
+        .is_quantized             = false,
+        .to_float                 = NULL,
+        .from_float               = NULL,
+        .from_float_reference     = NULL,
+        .vec_dot                  = NULL,
+        .vec_dot_type             = GGML_TYPE_COUNT,
+    },
+    [GGML_TYPE_Q5_0] = {
+        .type_name                = "q5_0",
+        .blck_size                = QK5_0,
+        .type_size                = sizeof(block_q5_0),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_0,
+        .from_float               = quantize_row_q5_0,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_0_reference,
+        .vec_dot                  = ggml_vec_dot_q5_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+    },
+    [GGML_TYPE_Q5_1] = {
+        .type_name                = "q5_1",
+        .blck_size                = QK5_1,
+        .type_size                = sizeof(block_q5_1),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_1,
+        .from_float               = quantize_row_q5_1,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_1_reference,
+        .vec_dot                  = ggml_vec_dot_q5_1_q8_1,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+    },
+    [GGML_TYPE_Q8_0] = {
+        .type_name                = "q8_0",
+        .blck_size                = QK8_0,
+        .type_size                = sizeof(block_q8_0),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q8_0,
+        .from_float               = quantize_row_q8_0,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q8_0_reference,
+        .vec_dot                  = ggml_vec_dot_q8_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+    },
+    [GGML_TYPE_Q8_1] = {
+        .type_name                = "q8_1",
+        .blck_size                = QK8_1,
+        .type_size                = sizeof(block_q8_1),
+        .is_quantized             = true,
+        .from_float               = quantize_row_q8_1,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q8_1_reference,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+    },
+    [GGML_TYPE_Q2_K] = {
+        .type_name                = "q2_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q2_K),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q2_K,
+        .from_float               = quantize_row_q2_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q2_K_reference,
+        .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q3_K] = {
+        .type_name                = "q3_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q3_K),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q3_K,
+        .from_float               = quantize_row_q3_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q3_K_reference,
+        .vec_dot                  = ggml_vec_dot_q3_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q4_K] = {
+        .type_name                = "q4_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q4_K),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_K,
+        .from_float               = quantize_row_q4_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_K_reference,
+        .vec_dot                  = ggml_vec_dot_q4_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q5_K] = {
+        .type_name                = "q5_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q5_K),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_K,
+        .from_float               = quantize_row_q5_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_K_reference,
+        .vec_dot                  = ggml_vec_dot_q5_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q6_K] = {
+        .type_name                = "q6_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q6_K),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q6_K,
+        .from_float               = quantize_row_q6_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q6_K_reference,
+        .vec_dot                  = ggml_vec_dot_q6_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q8_K] = {
+        .type_name                = "q8_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q8_K),
+        .is_quantized             = true,
+        .from_float               = quantize_row_q8_K,
+    }
+};
 
-// multiply int8_t, add results pairwise twice and return as float vector
-static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-#if __AVXVNNIINT8__
-    const __m256i zero = _mm256_setzero_si256();
-    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
-    return _mm256_cvtepi32_ps(summed_pairs);
-#else
-    // Get absolute values of x vectors
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    // Sign the values of the y vectors
-    const __m256i sy = _mm256_sign_epi8(y, x);
-    return mul_sum_us8_pairs_float(ax, sy);
-#endif
+// For internal test use
+ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
+    GGML_ASSERT(type < GGML_TYPE_COUNT);
+    return type_traits[type];
 }
 
-static inline __m128i packNibbles( __m256i bytes )
-{
-    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
-#if __AVX512F__
-    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000
-    bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh
-    return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh
-#else
-    const __m256i lowByte = _mm256_set1_epi16( 0xFF );
-    __m256i high = _mm256_andnot_si256( lowByte, bytes );
-    __m256i low = _mm256_and_si256( lowByte, bytes );
-    high = _mm256_srli_epi16( high, 4 );
-    bytes = _mm256_or_si256( low, high );
-
-    // Compress uint16_t lanes into bytes
-    __m128i r0 = _mm256_castsi256_si128( bytes );
-    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
-    return _mm_packus_epi16( r0, r1 );
-#endif
-}
-#elif defined(__AVX__)
-// spread 32 bits to 32 bytes { 0x00, 0xFF }
-static inline __m256i bytes_from_bits_32(const uint8_t * x) {
-    uint32_t x32;
-    memcpy(&x32, x, sizeof(uint32_t));
-    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
-    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
-    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
-    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
-    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
-    bytesl = _mm_or_si128(bytesl, bit_mask);
-    bytesh = _mm_or_si128(bytesh, bit_mask);
-    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
-    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
-    return MM256_SET_M128I(bytesh, bytesl);
-}
-
-// Unpack 32 4-bit fields into 32 bytes
-// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
-{
-    // Load 16 bytes from memory
-    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
-    __m128i tmph = _mm_srli_epi16(tmpl, 4);
-    const __m128i lowMask = _mm_set1_epi8(0xF);
-    tmpl = _mm_and_si128(lowMask, tmpl);
-    tmph = _mm_and_si128(lowMask, tmph);
-    return MM256_SET_M128I(tmph, tmpl);
-}
-
-// add int16_t pairwise and return as float vector
-static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
-    const __m128i ones = _mm_set1_epi16(1);
-    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
-    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
-    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
-    return _mm256_cvtepi32_ps(summed_pairs);
-}
-
-static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
-    const __m128i axl = _mm256_castsi256_si128(ax);
-    const __m128i axh = _mm256_extractf128_si256(ax, 1);
-    const __m128i syl = _mm256_castsi256_si128(sy);
-    const __m128i syh = _mm256_extractf128_si256(sy, 1);
-    // Perform multiplication and create 16-bit values
-    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
-    const __m128i doth = _mm_maddubs_epi16(axh, syh);
-    return sum_i16_pairs_float(doth, dotl);
-}
-
-// multiply int8_t, add results pairwise twice and return as float vector
-static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-    const __m128i xl = _mm256_castsi256_si128(x);
-    const __m128i xh = _mm256_extractf128_si256(x, 1);
-    const __m128i yl = _mm256_castsi256_si128(y);
-    const __m128i yh = _mm256_extractf128_si256(y, 1);
-    // Get absolute values of x vectors
-    const __m128i axl = _mm_sign_epi8(xl, xl);
-    const __m128i axh = _mm_sign_epi8(xh, xh);
-    // Sign the values of the y vectors
-    const __m128i syl = _mm_sign_epi8(yl, xl);
-    const __m128i syh = _mm_sign_epi8(yh, xh);
-    // Perform multiplication and create 16-bit values
-    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
-    const __m128i doth = _mm_maddubs_epi16(axh, syh);
-    return sum_i16_pairs_float(doth, dotl);
-}
-
-static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
-{
-    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
-    const __m128i lowByte = _mm_set1_epi16( 0xFF );
-    __m128i high = _mm_andnot_si128( lowByte, bytes1 );
-    __m128i low = _mm_and_si128( lowByte, bytes1 );
-    high = _mm_srli_epi16( high, 4 );
-    bytes1 = _mm_or_si128( low, high );
-    high = _mm_andnot_si128( lowByte, bytes2 );
-    low = _mm_and_si128( lowByte, bytes2 );
-    high = _mm_srli_epi16( high, 4 );
-    bytes2 = _mm_or_si128( low, high );
-
-    return _mm_packus_epi16( bytes1, bytes2);
-}
-#endif
-#elif defined(__SSSE3__)
-// horizontally add 4x4 floats
-static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
-    __m128 res_0 =_mm_hadd_ps(a, b);
-    __m128 res_1 =_mm_hadd_ps(c, d);
-    __m128 res =_mm_hadd_ps(res_0, res_1);
-    res =_mm_hadd_ps(res, res);
-    res =_mm_hadd_ps(res, res);
-
-    return _mm_cvtss_f32(res);
-}
-#endif // __AVX__ || __AVX2__ || __AVX512F__
-#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
-
-#if defined(__ARM_NEON)
-
-#if !defined(__aarch64__)
-
-inline static int32_t vaddvq_s32(int32x4_t v) {
-    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
-}
-
-inline static float vaddvq_f32(float32x4_t v) {
-    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
-}
-
-inline static float vmaxvq_f32(float32x4_t v) {
-    return
-        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
-            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
-}
-
-inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
-    int32x4_t res;
-
-    res[0] = roundf(vgetq_lane_f32(v, 0));
-    res[1] = roundf(vgetq_lane_f32(v, 1));
-    res[2] = roundf(vgetq_lane_f32(v, 2));
-    res[3] = roundf(vgetq_lane_f32(v, 3));
-
-    return res;
-}
-
-#endif
-#endif
-
-#define QK4_0 32
-typedef struct {
-    ggml_fp16_t d;          // delta
-    uint8_t qs[QK4_0 / 2];  // nibbles / quants
-} block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
-
-#define QK4_1 32
-typedef struct {
-    ggml_fp16_t d;          // delta
-    ggml_fp16_t m;          // min
-    uint8_t qs[QK4_1 / 2];  // nibbles / quants
-} block_q4_1;
-static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
-
-#define QK5_0 32
-typedef struct {
-    ggml_fp16_t d;         // delta
-    uint8_t qh[4];         // 5-th bit of quants
-    uint8_t qs[QK5_0 / 2]; // nibbles / quants
-} block_q5_0;
-static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
-
-#define QK5_1 32
-typedef struct {
-    ggml_fp16_t d;         // delta
-    ggml_fp16_t m;         // min
-    uint8_t qh[4];         // 5-th bit of quants
-    uint8_t qs[QK5_1 / 2]; // nibbles / quants
-} block_q5_1;
-static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
-
-#define QK8_0 32
-typedef struct {
-    ggml_fp16_t d;         // delta
-    int8_t  qs[QK8_0];     // quants
-} block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
-
-#define QK8_1 32
-typedef struct {
-    float d;               // delta
-    float s;               // d * sum(qs[i])
-    int8_t  qs[QK8_1];     // quants
-} block_q8_1;
-static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
-
-// reference implementation for deterministic creation of model files
-static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
-    static const int qk = QK4_0;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < qk; j++) {
-            const float v = x[i*qk + j];
-            if (amax < fabsf(v)) {
-                amax = fabsf(v);
-                max  = v;
-            }
-        }
-
-        const float d  = max / -8;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = GGML_FP32_TO_FP16(d);
-
-        for (int j = 0; j < qk/2; ++j) {
-            const float x0 = x[i*qk + 0    + j]*id;
-            const float x1 = x[i*qk + qk/2 + j]*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
-
-            y[i].qs[j]  = xi0;
-            y[i].qs[j] |= xi1 << 4;
-        }
-    }
-}
-
-static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
-    quantize_row_q4_0_reference(x, y, k);
-}
-
-static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
-    const int qk = QK4_1;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-
-        for (int j = 0; j < qk; j++) {
-            const float v = x[i*qk + j];
-
-            if (v < min) min = v;
-            if (v > max) max = v;
-        }
-
-        const float d  = (max - min) / ((1 << 4) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = GGML_FP32_TO_FP16(d);
-        y[i].m = GGML_FP32_TO_FP16(min);
-
-        for (int j = 0; j < qk/2; ++j) {
-            const float x0 = (x[i*qk + 0    + j] - min)*id;
-            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
-
-            y[i].qs[j]  = xi0;
-            y[i].qs[j] |= xi1 << 4;
-        }
-    }
-}
-
-static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
-    quantize_row_q4_1_reference(x, y, k);
-}
-
-static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
-    static const int qk = QK5_0;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < qk; j++) {
-            const float v = x[i*qk + j];
-            if (amax < fabsf(v)) {
-                amax = fabsf(v);
-                max  = v;
-            }
-        }
-
-        const float d  = max / -16;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = GGML_FP32_TO_FP16(d);
-
-        uint32_t qh = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const float x0 = x[i*qk + 0    + j]*id;
-            const float x1 = x[i*qk + qk/2 + j]*id;
-
-            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
-            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
-
-            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
-
-            // get the 5-th bit and store it in qh at the right position
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
-        }
-
-        memcpy(&y[i].qh, &qh, sizeof(qh));
-    }
-}
-
-static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
-    quantize_row_q5_0_reference(x, y, k);
-}
-
-static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
-    const int qk = QK5_1;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-
-        for (int j = 0; j < qk; j++) {
-            const float v = x[i*qk + j];
-
-            if (v < min) min = v;
-            if (v > max) max = v;
-        }
-
-        const float d  = (max - min) / ((1 << 5) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = GGML_FP32_TO_FP16(d);
-        y[i].m = GGML_FP32_TO_FP16(min);
-
-        uint32_t qh = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const float x0 = (x[i*qk + 0    + j] - min)*id;
-            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
-
-            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
-            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
-
-            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
-
-            // get the 5-th bit and store it in qh at the right position
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
-        }
-
-        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
-    }
-}
-
-static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
-    quantize_row_q5_1_reference(x, y, k);
-}
-
-// reference implementation for deterministic creation of model files
-static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
-    assert(k % QK8_0 == 0);
-    const int nb = k / QK8_0;
-
-    for (int i = 0; i < nb; i++) {
-        float amax = 0.0f; // absolute max
-
-        for (int j = 0; j < QK8_0; j++) {
-            const float v = x[i*QK8_0 + j];
-            amax = MAX(amax, fabsf(v));
-        }
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = GGML_FP32_TO_FP16(d);
-
-        for (int j = 0; j < QK8_0; ++j) {
-            const float x0 = x[i*QK8_0 + j]*id;
-
-            y[i].qs[j] = roundf(x0);
-        }
-    }
-}
-
-static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
-    assert(QK8_0 == 32);
-    assert(k % QK8_0 == 0);
-    const int nb = k / QK8_0;
-
-    block_q8_0 * restrict y = vy;
-
-#if defined(__ARM_NEON)
-    for (int i = 0; i < nb; i++) {
-        float32x4_t srcv [8];
-        float32x4_t asrcv[8];
-        float32x4_t amaxv[8];
-
-        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
-        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
-
-        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
-        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
-        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
-
-        const float amax = vmaxvq_f32(amaxv[0]);
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = GGML_FP32_TO_FP16(d);
-
-        for (int j = 0; j < 8; j++) {
-            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
-            const int32x4_t   vi = vcvtnq_s32_f32(v);
-
-            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
-            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
-            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
-            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
-        }
-    }
-#elif defined(__wasm_simd128__)
-    for (int i = 0; i < nb; i++) {
-        v128_t srcv [8];
-        v128_t asrcv[8];
-        v128_t amaxv[8];
-
-        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
-        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
-
-        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
-        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
-        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
-
-        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
-                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
-                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
-                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = GGML_FP32_TO_FP16(d);
-
-        for (int j = 0; j < 8; j++) {
-            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
-            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
-
-            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
-            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
-            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
-            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
-        }
-    }
-#elif defined(__AVX2__) || defined(__AVX__)
-    for (int i = 0; i < nb; i++) {
-        // Load elements into 4 AVX vectors
-        __m256 v0 = _mm256_loadu_ps( x );
-        __m256 v1 = _mm256_loadu_ps( x + 8 );
-        __m256 v2 = _mm256_loadu_ps( x + 16 );
-        __m256 v3 = _mm256_loadu_ps( x + 24 );
-        x += 32;
-
-        // Compute max(abs(e)) for the block
-        const __m256 signBit = _mm256_set1_ps( -0.0f );
-        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
-
-        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
-        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
-        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
-        const float maxScalar = _mm_cvtss_f32( max4 );
-
-        // Quantize these floats
-        const float d = maxScalar / 127.f;
-        y[i].d = GGML_FP32_TO_FP16(d);
-        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
-        const __m256 mul = _mm256_set1_ps( id );
-
-        // Apply the multiplier
-        v0 = _mm256_mul_ps( v0, mul );
-        v1 = _mm256_mul_ps( v1, mul );
-        v2 = _mm256_mul_ps( v2, mul );
-        v3 = _mm256_mul_ps( v3, mul );
-
-        // Round to nearest integer
-        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
-        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
-        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
-        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
-
-        // Convert floats to integers
-        __m256i i0 = _mm256_cvtps_epi32( v0 );
-        __m256i i1 = _mm256_cvtps_epi32( v1 );
-        __m256i i2 = _mm256_cvtps_epi32( v2 );
-        __m256i i3 = _mm256_cvtps_epi32( v3 );
-
-#if defined(__AVX2__)
-        // Convert int32 to int16
-        i0 = _mm256_packs_epi32( i0, i1 );     // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
-        i2 = _mm256_packs_epi32( i2, i3 );     // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
-                                            // Convert int16 to int8
-        i0 = _mm256_packs_epi16( i0, i2 );     // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-
-        // We got our precious signed bytes, but the order is now wrong
-        // These AVX2 pack instructions process 16-byte pieces independently
-        // The following instruction is fixing the order
-        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
-        i0 = _mm256_permutevar8x32_epi32( i0, perm );
-
-        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
-#else
-        // Since we don't have in AVX some necessary functions,
-        // we split the registers in half and call AVX2 analogs from SSE
-        __m128i ni0 = _mm256_castsi256_si128( i0 );
-        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
-        __m128i ni2 = _mm256_castsi256_si128( i1 );
-        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
-        __m128i ni4 = _mm256_castsi256_si128( i2 );
-        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
-        __m128i ni6 = _mm256_castsi256_si128( i3 );
-        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
-
-        // Convert int32 to int16
-        ni0 = _mm_packs_epi32( ni0, ni1 );
-        ni2 = _mm_packs_epi32( ni2, ni3 );
-        ni4 = _mm_packs_epi32( ni4, ni5 );
-        ni6 = _mm_packs_epi32( ni6, ni7 );
-        // Convert int16 to int8
-        ni0 = _mm_packs_epi16( ni0, ni2 );
-        ni4 = _mm_packs_epi16( ni4, ni6 );
-
-        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
-        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
-#endif
-    }
-#elif defined(__riscv_v_intrinsic)
-
-    size_t vl = __riscv_vsetvl_e32m4(QK8_0);
-
-    for (int i = 0; i < nb; i++) {
-        // load elements
-        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
-
-        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
-        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
-        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
-        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = GGML_FP32_TO_FP16(d);
-
-        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
-
-        // convert to integer
-        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
-        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
-
-        // store result
-        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
-    }
-#else
-    // scalar
-    quantize_row_q8_0_reference(x, y, k);
-#endif
-}
-
-// reference implementation for deterministic creation of model files
-static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
-    assert(QK8_1 == 32);
-    assert(k % QK8_1 == 0);
-    const int nb = k / QK8_1;
-
-    for (int i = 0; i < nb; i++) {
-        float amax = 0.0f; // absolute max
-
-        for (int j = 0; j < QK8_1; j++) {
-            const float v = x[i*QK8_1 + j];
-            amax = MAX(amax, fabsf(v));
-        }
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = d;
-
-        int sum = 0;
-
-        for (int j = 0; j < QK8_1/2; ++j) {
-            const float v0 = x[i*QK8_1           + j]*id;
-            const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
-
-            y[i].qs[          j] = roundf(v0);
-            y[i].qs[QK8_1/2 + j] = roundf(v1);
-
-            sum += y[i].qs[          j];
-            sum += y[i].qs[QK8_1/2 + j];
-        }
-
-        y[i].s = sum*d;
-    }
-}
-
-static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK8_1 == 0);
-    const int nb = k / QK8_1;
-
-    block_q8_1 * restrict y = vy;
-
-#if defined(__ARM_NEON)
-    for (int i = 0; i < nb; i++) {
-        float32x4_t srcv [8];
-        float32x4_t asrcv[8];
-        float32x4_t amaxv[8];
-
-        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
-        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
-
-        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
-        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
-        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
-
-        const float amax = vmaxvq_f32(amaxv[0]);
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = d;
-
-        int32x4_t accv = vdupq_n_s32(0);
-
-        for (int j = 0; j < 8; j++) {
-            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
-            const int32x4_t   vi = vcvtnq_s32_f32(v);
-
-            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
-            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
-            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
-            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
-
-            accv = vaddq_s32(accv, vi);
-        }
-
-        y[i].s = d * vaddvq_s32(accv);
-    }
-#elif defined(__wasm_simd128__)
-    for (int i = 0; i < nb; i++) {
-        v128_t srcv [8];
-        v128_t asrcv[8];
-        v128_t amaxv[8];
-
-        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
-        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
-
-        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
-        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
-        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
-
-        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
-                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
-                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
-                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = d;
-
-        v128_t accv = wasm_i32x4_splat(0);
-
-        for (int j = 0; j < 8; j++) {
-            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
-            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
-
-            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
-            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
-            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
-            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
-
-            accv = wasm_i32x4_add(accv, vi);
-        }
-
-        y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
-                      wasm_i32x4_extract_lane(accv, 1) +
-                      wasm_i32x4_extract_lane(accv, 2) +
-                      wasm_i32x4_extract_lane(accv, 3));
-    }
-#elif defined(__AVX2__) || defined(__AVX__)
-    for (int i = 0; i < nb; i++) {
-        // Load elements into 4 AVX vectors
-        __m256 v0 = _mm256_loadu_ps( x );
-        __m256 v1 = _mm256_loadu_ps( x + 8 );
-        __m256 v2 = _mm256_loadu_ps( x + 16 );
-        __m256 v3 = _mm256_loadu_ps( x + 24 );
-        x += 32;
-
-        // Compute max(abs(e)) for the block
-        const __m256 signBit = _mm256_set1_ps( -0.0f );
-        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
-
-        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
-        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
-        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
-        const float maxScalar = _mm_cvtss_f32( max4 );
-
-        // Quantize these floats
-        const float d = maxScalar / 127.f;
-        y[i].d = d;
-        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
-        const __m256 mul = _mm256_set1_ps( id );
-
-        // Apply the multiplier
-        v0 = _mm256_mul_ps( v0, mul );
-        v1 = _mm256_mul_ps( v1, mul );
-        v2 = _mm256_mul_ps( v2, mul );
-        v3 = _mm256_mul_ps( v3, mul );
-
-        // Round to nearest integer
-        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
-        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
-        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
-        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
-
-        // Convert floats to integers
-        __m256i i0 = _mm256_cvtps_epi32( v0 );
-        __m256i i1 = _mm256_cvtps_epi32( v1 );
-        __m256i i2 = _mm256_cvtps_epi32( v2 );
-        __m256i i3 = _mm256_cvtps_epi32( v3 );
-
-#if defined(__AVX2__)
-        // Compute the sum of the quants and set y[i].s
-        y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
-
-        // Convert int32 to int16
-        i0 = _mm256_packs_epi32( i0, i1 );     // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
-        i2 = _mm256_packs_epi32( i2, i3 );     // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
-                                            // Convert int16 to int8
-        i0 = _mm256_packs_epi16( i0, i2 );     // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-
-        // We got our precious signed bytes, but the order is now wrong
-        // These AVX2 pack instructions process 16-byte pieces independently
-        // The following instruction is fixing the order
-        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
-        i0 = _mm256_permutevar8x32_epi32( i0, perm );
-
-        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
-#else
-        // Since we don't have in AVX some necessary functions,
-        // we split the registers in half and call AVX2 analogs from SSE
-        __m128i ni0 = _mm256_castsi256_si128( i0 );
-        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
-        __m128i ni2 = _mm256_castsi256_si128( i1 );
-        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
-        __m128i ni4 = _mm256_castsi256_si128( i2 );
-        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
-        __m128i ni6 = _mm256_castsi256_si128( i3 );
-        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
-
-        // Compute the sum of the quants and set y[i].s
-        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
-        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
-        y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
-
-        // Convert int32 to int16
-        ni0 = _mm_packs_epi32( ni0, ni1 );
-        ni2 = _mm_packs_epi32( ni2, ni3 );
-        ni4 = _mm_packs_epi32( ni4, ni5 );
-        ni6 = _mm_packs_epi32( ni6, ni7 );
-        // Convert int16 to int8
-        ni0 = _mm_packs_epi16( ni0, ni2 );
-        ni4 = _mm_packs_epi16( ni4, ni6 );
-
-        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
-        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
-#endif
-    }
-#elif defined(__riscv_v_intrinsic)
-
-    size_t vl = __riscv_vsetvl_e32m4(QK8_1);
-
-    for (int i = 0; i < nb; i++) {
-        // load elements
-        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
-
-        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
-        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0, vl);
-        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
-        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
-
-        const float d  = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = d;
-
-        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
-
-        // convert to integer
-        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
-        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
-
-        // store result
-        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
-
-        // compute sum for y[i].s
-        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
-        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
-
-        // set y[i].s
-        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
-        y[i].s = sum*d;
-    }
-#else
-    // scalar
-    quantize_row_q8_1_reference(x, y, k);
-#endif
-}
-
-static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
-    static const int qk = QK4_0;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        const float d = GGML_FP16_TO_FP32(x[i].d);
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int x0 = (x[i].qs[j] & 0x0F) - 8;
-            const int x1 = (x[i].qs[j] >>   4) - 8;
-
-            y[i*qk + j + 0   ] = x0*d;
-            y[i*qk + j + qk/2] = x1*d;
-        }
-    }
-}
-
-static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
-    static const int qk = QK4_1;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        const float d = GGML_FP16_TO_FP32(x[i].d);
-        const float m = GGML_FP16_TO_FP32(x[i].m);
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int x0 = (x[i].qs[j] & 0x0F);
-            const int x1 = (x[i].qs[j] >>   4);
-
-            y[i*qk + j + 0   ] = x0*d + m;
-            y[i*qk + j + qk/2] = x1*d + m;
-        }
-    }
-}
-
-static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
-    static const int qk = QK5_0;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        const float d = GGML_FP16_TO_FP32(x[i].d);
-
-        uint32_t qh;
-        memcpy(&qh, x[i].qh, sizeof(qh));
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
-            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;
-
-            y[i*qk + j + 0   ] = x0*d;
-            y[i*qk + j + qk/2] = x1*d;
-        }
-    }
-}
-
-static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
-    static const int qk = QK5_1;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        const float d = GGML_FP16_TO_FP32(x[i].d);
-        const float m = GGML_FP16_TO_FP32(x[i].m);
-
-        uint32_t qh;
-        memcpy(&qh, x[i].qh, sizeof(qh));
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-            const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
-            const int x1 = (x[i].qs[j] >>   4) | xh_1;
-
-            y[i*qk + j + 0   ] = x0*d + m;
-            y[i*qk + j + qk/2] = x1*d + m;
-        }
-    }
-}
-
-static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
-    static const int qk = QK8_0;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    const block_q8_0 * restrict x = vx;
-
-    for (int i = 0; i < nb; i++) {
-        const float d = GGML_FP16_TO_FP32(x[i].d);
-
-        for (int j = 0; j < qk; ++j) {
-            y[i*qk + j] = x[i].qs[j]*d;
-        }
-    }
-}
-
-static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
-static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
-static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-
-static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_I8] = {
-        .type_name                = "i8",
-        .blck_size                = 1,
-        .type_size                = sizeof(int8_t),
-        .is_quantized             = false,
-    },
-    [GGML_TYPE_I16] = {
-        .type_name                = "i16",
-        .blck_size                = 1,
-        .type_size                = sizeof(int16_t),
-        .is_quantized             = false,
-    },
-    [GGML_TYPE_I32] = {
-        .type_name                = "i32",
-        .blck_size                = 1,
-        .type_size                = sizeof(int32_t),
-        .is_quantized             = false,
-    },
-    [GGML_TYPE_F32] = {
-        .type_name                = "f32",
-        .blck_size                = 1,
-        .type_size                = sizeof(float),
-        .is_quantized             = false,
-        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
-        .vec_dot_type             = GGML_TYPE_F32,
-    },
-    [GGML_TYPE_F16] = {
-        .type_name                = "f16",
-        .blck_size                = 1,
-        .type_size                = sizeof(ggml_fp16_t),
-        .is_quantized             = false,
-        .to_float                 = (ggml_to_float_t) ggml_fp16_to_fp32_row,
-        .from_float               = (ggml_from_float_t) ggml_fp32_to_fp16_row,
-        .from_float_reference     = (ggml_from_float_t) ggml_fp32_to_fp16_row,
-        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f16,
-        .vec_dot_type             = GGML_TYPE_F16,
-    },
-    [GGML_TYPE_Q4_0] = {
-        .type_name                = "q4_0",
-        .blck_size                = QK4_0,
-        .type_size                = sizeof(block_q4_0),
-        .is_quantized             = true,
-        .to_float                 = (ggml_to_float_t) dequantize_row_q4_0,
-        .from_float               = quantize_row_q4_0,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_0_reference,
-        .vec_dot                  = ggml_vec_dot_q4_0_q8_0,
-        .vec_dot_type             = GGML_TYPE_Q8_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .type_name                = "q4_1",
-        .blck_size                = QK4_1,
-        .type_size                = sizeof(block_q4_1),
-        .is_quantized             = true,
-        .to_float                 = (ggml_to_float_t) dequantize_row_q4_1,
-        .from_float               = quantize_row_q4_1,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_1_reference,
-        .vec_dot                  = ggml_vec_dot_q4_1_q8_1,
-        .vec_dot_type             = GGML_TYPE_Q8_1,
-    },
-    [GGML_TYPE_Q5_0] = {
-        .type_name                = "q5_0",
-        .blck_size                = QK5_0,
-        .type_size                = sizeof(block_q5_0),
-        .is_quantized             = true,
-        .to_float                 = (ggml_to_float_t) dequantize_row_q5_0,
-        .from_float               = quantize_row_q5_0,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_0_reference,
-        .vec_dot                  = ggml_vec_dot_q5_0_q8_0,
-        .vec_dot_type             = GGML_TYPE_Q8_0,
-    },
-    [GGML_TYPE_Q5_1] = {
-        .type_name                = "q5_1",
-        .blck_size                = QK5_1,
-        .type_size                = sizeof(block_q5_1),
-        .is_quantized             = true,
-        .to_float                 = (ggml_to_float_t) dequantize_row_q5_1,
-        .from_float               = quantize_row_q5_1,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_1_reference,
-        .vec_dot                  = ggml_vec_dot_q5_1_q8_1,
-        .vec_dot_type             = GGML_TYPE_Q8_1,
-    },
-    [GGML_TYPE_Q8_0] = {
-        .type_name                = "q8_0",
-        .blck_size                = QK8_0,
-        .type_size                = sizeof(block_q8_0),
-        .is_quantized             = true,
-        .to_float                 = dequantize_row_q8_0,
-        .from_float               = quantize_row_q8_0,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_q8_0_reference,
-        .vec_dot                  = ggml_vec_dot_q8_0_q8_0,
-        .vec_dot_type             = GGML_TYPE_Q8_0,
-    },
-    [GGML_TYPE_Q8_1] = {
-        .type_name                = "q8_1",
-        .blck_size                = QK8_1,
-        .type_size                = sizeof(block_q8_1),
-        .is_quantized             = true,
-        .from_float               = quantize_row_q8_1,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_q8_1_reference,
-        .vec_dot_type             = GGML_TYPE_Q8_1,
-    },
-#ifdef GGML_USE_K_QUANTS
-    [GGML_TYPE_Q2_K] = {
-        .type_name                = "q2_K",
-        .blck_size                = QK_K,
-        .type_size                = sizeof(block_q2_K),
-        .is_quantized             = true,
-        .to_float                 = (ggml_to_float_t) dequantize_row_q2_K,
-        .from_float               = quantize_row_q2_K,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_q2_K_reference,
-        .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-    },
-    [GGML_TYPE_Q3_K] = {
-        .type_name                = "q3_K",
-        .blck_size                = QK_K,
-        .type_size                = sizeof(block_q3_K),
-        .is_quantized             = true,
-        .to_float                 = (ggml_to_float_t) dequantize_row_q3_K,
-        .from_float               = quantize_row_q3_K,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_q3_K_reference,
-        .vec_dot                  = ggml_vec_dot_q3_K_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-    },
-    [GGML_TYPE_Q4_K] = {
-        .type_name                = "q4_K",
-        .blck_size                = QK_K,
-        .type_size                = sizeof(block_q4_K),
-        .is_quantized             = true,
-        .to_float                 = (ggml_to_float_t) dequantize_row_q4_K,
-        .from_float               = quantize_row_q4_K,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_K_reference,
-        .vec_dot                  = ggml_vec_dot_q4_K_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-    },
-    [GGML_TYPE_Q5_K] = {
-        .type_name                = "q5_K",
-        .blck_size                = QK_K,
-        .type_size                = sizeof(block_q5_K),
-        .is_quantized             = true,
-        .to_float                 = (ggml_to_float_t) dequantize_row_q5_K,
-        .from_float               = quantize_row_q5_K,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_K_reference,
-        .vec_dot                  = ggml_vec_dot_q5_K_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-    },
-    [GGML_TYPE_Q6_K] = {
-        .type_name                = "q6_K",
-        .blck_size                = QK_K,
-        .type_size                = sizeof(block_q6_K),
-        .is_quantized             = true,
-        .to_float                 = (ggml_to_float_t) dequantize_row_q6_K,
-        .from_float               = quantize_row_q6_K,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_q6_K_reference,
-        .vec_dot                  = ggml_vec_dot_q6_K_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-    },
-    [GGML_TYPE_Q8_K] = {
-        .type_name                = "q8_K",
-        .blck_size                = QK_K,
-        .type_size                = sizeof(block_q8_K),
-        .is_quantized             = true,
-        .from_float               = quantize_row_q8_K,
-    }
-#endif
-};
-
-// For internal test use
-ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
-    GGML_ASSERT(type < GGML_TYPE_COUNT);
-    return type_traits[type];
-}
-
-//
-// simd mappings
-//
-
-// we define a common set of C macros which map to specific intrinsics based on the current architecture
-// we then implement the fundamental computation operations below using only these macros
-// adding support for new architectures requires to define the corresponding SIMD macros
-//
-// GGML_F32_STEP / GGML_F16_STEP
-//   number of elements to process in a single step
-//
-// GGML_F32_EPR / GGML_F16_EPR
-//   number of elements to fit in a single register
-//
-
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
-
-#define GGML_SIMD
-
-// F32 NEON
-
-#define GGML_F32_STEP 16
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4              float32x4_t
-#define GGML_F32x4_ZERO         vdupq_n_f32(0.0f)
-#define GGML_F32x4_SET1(x)      vdupq_n_f32(x)
-#define GGML_F32x4_LOAD         vld1q_f32
-#define GGML_F32x4_STORE        vst1q_f32
-#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
-#define GGML_F32x4_ADD          vaddq_f32
-#define GGML_F32x4_MUL          vmulq_f32
-#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
-#define GGML_F32x4_REDUCE(res, x)              \
-{                                              \
-    int offset = GGML_F32_ARR >> 1;            \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vaddq_f32(x[i], x[offset+i]);   \
-    }                                          \
-    offset >>= 1;                              \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vaddq_f32(x[i], x[offset+i]);   \
-    }                                          \
-    offset >>= 1;                              \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vaddq_f32(x[i], x[offset+i]);   \
-    }                                          \
-    res = GGML_F32x4_REDUCE_ONE(x[0]);         \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 NEON
-
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
-    #define GGML_F16_STEP 32
-    #define GGML_F16_EPR  8
-
-    #define GGML_F16x8              float16x8_t
-    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)
-    #define GGML_F16x8_SET1(x)      vdupq_n_f16(x)
-    #define GGML_F16x8_LOAD         vld1q_f16
-    #define GGML_F16x8_STORE        vst1q_f16
-    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
-    #define GGML_F16x8_ADD          vaddq_f16
-    #define GGML_F16x8_MUL          vmulq_f16
-    #define GGML_F16x8_REDUCE(res, x)                             \
-    do {                                                          \
-        int offset = GGML_F16_ARR >> 1;                           \
-        for (int i = 0; i < offset; ++i) {                        \
-            x[i] = vaddq_f16(x[i], x[offset+i]);                  \
-        }                                                         \
-        offset >>= 1;                                             \
-        for (int i = 0; i < offset; ++i) {                        \
-            x[i] = vaddq_f16(x[i], x[offset+i]);                  \
-        }                                                         \
-        offset >>= 1;                                             \
-        for (int i = 0; i < offset; ++i) {                        \
-            x[i] = vaddq_f16(x[i], x[offset+i]);                  \
-        }                                                         \
-        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
-        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
-        res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
-    } while (0)
-
-    #define GGML_F16_VEC                GGML_F16x8
-    #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
-    #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
-    #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
-    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
-    #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
-    #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
-    #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
-    #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE
-#else
-    // if FP16 vector arithmetic is not supported, we use FP32 instead
-    // and take advantage of the vcvt_ functions to convert to/from FP16
-
-    #define GGML_F16_STEP 16
-    #define GGML_F16_EPR  4
-
-    #define GGML_F32Cx4              float32x4_t
-    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)
-    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)
-    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16(x))
-    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))
-    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
-    #define GGML_F32Cx4_ADD          vaddq_f32
-    #define GGML_F32Cx4_MUL          vmulq_f32
-    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE
-
-    #define GGML_F16_VEC                GGML_F32Cx4
-    #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
-    #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
-    #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
-    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
-    #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
-    #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
-    #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
-    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
-#endif
-
-#elif defined(__AVX__)
-
-#define GGML_SIMD
-
-// F32 AVX
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  8
-
-#define GGML_F32x8         __m256
-#define GGML_F32x8_ZERO    _mm256_setzero_ps()
-#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
-#define GGML_F32x8_LOAD    _mm256_loadu_ps
-#define GGML_F32x8_STORE   _mm256_storeu_ps
-#if defined(__FMA__)
-    #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
-#else
-    #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
-#endif
-#define GGML_F32x8_ADD     _mm256_add_ps
-#define GGML_F32x8_MUL     _mm256_mul_ps
-#define GGML_F32x8_REDUCE(res, x)                                 \
-do {                                                              \
-    int offset = GGML_F32_ARR >> 1;                               \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \
-                                 _mm256_extractf128_ps(x[0], 1)); \
-    const __m128 t1 = _mm_hadd_ps(t0, t0);                        \
-    res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));                     \
-} while (0)
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC        GGML_F32x8
-#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
-
-// F16 AVX
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR  8
-
-// F16 arithmetic is not supported by AVX, so we use F32 instead
-
-#define GGML_F32Cx8             __m256
-#define GGML_F32Cx8_ZERO        _mm256_setzero_ps()
-#define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)
-
-#if defined(__F16C__)
-// the  _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
-#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
-#else
-static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
-    float tmp[8];
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return _mm256_loadu_ps(tmp);
-}
-static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
-    float arr[8];
-
-    _mm256_storeu_ps(arr, y);
-
-    for (int i = 0; i < 8; i++)
-        x[i] = GGML_FP32_TO_FP16(arr[i]);
-}
-#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
-#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
-#endif
-
-#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
-#define GGML_F32Cx8_ADD         _mm256_add_ps
-#define GGML_F32Cx8_MUL         _mm256_mul_ps
-#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
-
-#define GGML_F16_VEC                GGML_F32Cx8
-#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
-#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
-#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
-#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
-#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
-#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
-#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
-
-#elif defined(__POWER9_VECTOR__)
-
-#define GGML_SIMD
-
-// F32 POWER9
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4              vector float
-#define GGML_F32x4_ZERO         0.0f
-#define GGML_F32x4_SET1         vec_splats
-#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
-#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
-#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
-#define GGML_F32x4_ADD          vec_add
-#define GGML_F32x4_MUL          vec_mul
-#define GGML_F32x4_REDUCE(res, x)              \
-{                                              \
-    int offset = GGML_F32_ARR >> 1;            \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vec_add(x[i], x[offset+i]);     \
-    }                                          \
-    offset >>= 1;                              \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vec_add(x[i], x[offset+i]);     \
-    }                                          \
-    offset >>= 1;                              \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vec_add(x[i], x[offset+i]);     \
-    }                                          \
-    res = vec_extract(x[0], 0) +               \
-          vec_extract(x[0], 1) +               \
-          vec_extract(x[0], 2) +               \
-          vec_extract(x[0], 3);                \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 POWER9
-#define GGML_F16_STEP       GGML_F32_STEP
-#define GGML_F16_EPR        GGML_F32_EPR
-#define GGML_F16_VEC        GGML_F32x4
-#define GGML_F16_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F16_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F16_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
-// Use vec_xl, not vec_ld, in case the load address is not aligned.
-#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?                   \
-  vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
-  vec_extract_fp32_from_shortl(vec_xl(0, p))
-#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
-#define GGML_F16_VEC_STORE(p, r, i)                             \
-  if (i & 0x1)                                                  \
-    vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)],  \
-                                   r[i - GGML_ENDIAN_BYTE(0)]), \
-            0, p - GGML_F16_EPR)
-
-#elif defined(__wasm_simd128__)
-
-#define GGML_SIMD
-
-// F32 WASM
-
-#define GGML_F32_STEP 16
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4              v128_t
-#define GGML_F32x4_ZERO         wasm_f32x4_splat(0.0f)
-#define GGML_F32x4_SET1(x)      wasm_f32x4_splat(x)
-#define GGML_F32x4_LOAD         wasm_v128_load
-#define GGML_F32x4_STORE        wasm_v128_store
-#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
-#define GGML_F32x4_ADD          wasm_f32x4_add
-#define GGML_F32x4_MUL          wasm_f32x4_mul
-#define GGML_F32x4_REDUCE(res, x)                  \
-{                                                  \
-    int offset = GGML_F32_ARR >> 1;                \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    res = wasm_f32x4_extract_lane(x[0], 0) +       \
-          wasm_f32x4_extract_lane(x[0], 1) +       \
-          wasm_f32x4_extract_lane(x[0], 2) +       \
-          wasm_f32x4_extract_lane(x[0], 3);        \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 WASM
-
-#define GGML_F16_STEP 16
-#define GGML_F16_EPR  4
-
-inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
-    float tmp[4];
-
-    tmp[0] = GGML_FP16_TO_FP32(p[0]);
-    tmp[1] = GGML_FP16_TO_FP32(p[1]);
-    tmp[2] = GGML_FP16_TO_FP32(p[2]);
-    tmp[3] = GGML_FP16_TO_FP32(p[3]);
-
-    return wasm_v128_load(tmp);
-}
-
-inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
-    float tmp[4];
-
-    wasm_v128_store(tmp, x);
-
-    p[0] = GGML_FP32_TO_FP16(tmp[0]);
-    p[1] = GGML_FP32_TO_FP16(tmp[1]);
-    p[2] = GGML_FP32_TO_FP16(tmp[2]);
-    p[3] = GGML_FP32_TO_FP16(tmp[3]);
-}
-
-#define GGML_F16x4             v128_t
-#define GGML_F16x4_ZERO        wasm_f32x4_splat(0.0f)
-#define GGML_F16x4_SET1(x)     wasm_f32x4_splat(x)
-#define GGML_F16x4_LOAD(x)     __wasm_f16x4_load(x)
-#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
-#define GGML_F16x4_FMA         GGML_F32x4_FMA
-#define GGML_F16x4_ADD         wasm_f32x4_add
-#define GGML_F16x4_MUL         wasm_f32x4_mul
-#define GGML_F16x4_REDUCE(res, x)                  \
-{                                                  \
-    int offset = GGML_F16_ARR >> 1;                \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    res = wasm_f32x4_extract_lane(x[0], 0) +       \
-          wasm_f32x4_extract_lane(x[0], 1) +       \
-          wasm_f32x4_extract_lane(x[0], 2) +       \
-          wasm_f32x4_extract_lane(x[0], 3);        \
-}
-
-#define GGML_F16_VEC                GGML_F16x4
-#define GGML_F16_VEC_ZERO           GGML_F16x4_ZERO
-#define GGML_F16_VEC_SET1           GGML_F16x4_SET1
-#define GGML_F16_VEC_LOAD(p, i)     GGML_F16x4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
-#define GGML_F16_VEC_FMA            GGML_F16x4_FMA
-#define GGML_F16_VEC_ADD            GGML_F16x4_ADD
-#define GGML_F16_VEC_MUL            GGML_F16x4_MUL
-#define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE
-
-#elif defined(__SSE3__)
-
-#define GGML_SIMD
-
-// F32 SSE
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4         __m128
-#define GGML_F32x4_ZERO    _mm_setzero_ps()
-#define GGML_F32x4_SET1(x) _mm_set1_ps(x)
-#define GGML_F32x4_LOAD    _mm_loadu_ps
-#define GGML_F32x4_STORE   _mm_storeu_ps
-#if defined(__FMA__)
-    // TODO: Does this work?
-    #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
-#else
-    #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
-#endif
-#define GGML_F32x4_ADD     _mm_add_ps
-#define GGML_F32x4_MUL     _mm_mul_ps
-#define GGML_F32x4_REDUCE(res, x)                                 \
-{                                                                 \
-    int offset = GGML_F32_ARR >> 1;                               \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
-    }                                                             \
-    const __m128 t0 = _mm_hadd_ps(x[0], x[0]);                    \
-    res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0));                     \
-}
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 SSE
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR  4
-
-static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
-    float tmp[4];
-
-    tmp[0] = GGML_FP16_TO_FP32(x[0]);
-    tmp[1] = GGML_FP16_TO_FP32(x[1]);
-    tmp[2] = GGML_FP16_TO_FP32(x[2]);
-    tmp[3] = GGML_FP16_TO_FP32(x[3]);
-
-    return _mm_loadu_ps(tmp);
-}
-
-static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
-    float arr[4];
-
-    _mm_storeu_ps(arr, y);
-
-    x[0] = GGML_FP32_TO_FP16(arr[0]);
-    x[1] = GGML_FP32_TO_FP16(arr[1]);
-    x[2] = GGML_FP32_TO_FP16(arr[2]);
-    x[3] = GGML_FP32_TO_FP16(arr[3]);
-}
-
-#define GGML_F32Cx4             __m128
-#define GGML_F32Cx4_ZERO        _mm_setzero_ps()
-#define GGML_F32Cx4_SET1(x)     _mm_set1_ps(x)
-#define GGML_F32Cx4_LOAD(x)     __sse_f16x4_load(x)
-#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
-#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
-#define GGML_F32Cx4_ADD         _mm_add_ps
-#define GGML_F32Cx4_MUL         _mm_mul_ps
-#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
-
-#define GGML_F16_VEC                 GGML_F32Cx4
-#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
-#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
-#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
-#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
-#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
-#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
-#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
-
-#endif
-
-// GGML_F32_ARR / GGML_F16_ARR
-//   number of registers to use per step
-#ifdef GGML_SIMD
-#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
-#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
-#endif
-
-//
-// fundamental operations
-//
-
-inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
-inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    }
-inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
-inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
-inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
-inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
-inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
-inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
-inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
-inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
-
-static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
-#ifdef GGML_SIMD
-    float sumf = 0.0f;
-    const int np = (n & ~(GGML_F32_STEP - 1));
-
-    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-
-    GGML_F32_VEC ax[GGML_F32_ARR];
-    GGML_F32_VEC ay[GGML_F32_ARR];
-
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-
-            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
-        }
-    }
-
-    // reduce sum0..sum3 to sum0
-    GGML_F32_VEC_REDUCE(sumf, sum);
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        sumf += x[i]*y[i];
-    }
-#else
-    // scalar
-    ggml_float sumf = 0.0;
-    for (int i = 0; i < n; ++i) {
-        sumf += (ggml_float)(x[i]*y[i]);
-    }
-#endif
-
-    *s = sumf;
-}
-
-static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
-    ggml_float sumf = 0.0;
-
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
-
-    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
-
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
-
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-
-            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
-        }
-    }
-
-    // reduce sum0..sum3 to sum0
-    GGML_F16_VEC_REDUCE(sumf, sum);
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
-    }
-#else
-    for (int i = 0; i < n; ++i) {
-        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
-    }
-#endif
-
-    *s = sumf;
-}
-
-static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-
-    assert(n % qk == 0);
-
-    const block_q4_0 * restrict x = vx;
-    const block_q8_0 * restrict y = vy;
-
-#if defined(__ARM_NEON)
-    float32x4_t sumv0 = vdupq_n_f32(0.0f);
-    float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
-    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0x0F);
-        const int8x16_t  s8b = vdupq_n_s8(0x8);
-
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-
-        // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
-        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
-        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-
-        // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
-        // load y
-        const int8x16_t v1_0l = vld1q_s8(y0->qs);
-        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-        const int8x16_t v1_1l = vld1q_s8(y1->qs);
-        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#endif
-    }
-
-    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        /* Compute combined scale for the block */
-        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
-
-        __m256i bx = bytes_from_nibbles_32(x[i].qs);
-
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-        const __m256i off = _mm256_set1_epi8( 8 );
-        bx = _mm256_sub_epi8( bx, off );
-
-        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
-
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
-
-        /* Multiply q with scale and accumulate */
-        acc = _mm256_fmadd_ps( d, q, acc );
-    }
-
-    *s = hsum_float_8(acc);
-#elif defined(__AVX__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        // Compute combined scale for the block
-        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
-
-        const __m128i lowMask = _mm_set1_epi8(0xF);
-        const __m128i off = _mm_set1_epi8(8);
-
-        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
-
-        __m128i bx = _mm_and_si128(lowMask, tmp);
-        __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
-        bx = _mm_sub_epi8(bx, off);
-        const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
-
-        bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
-        by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
-        bx = _mm_sub_epi8(bx, off);
-        const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
-
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
-
-        // Apply the scale, and accumulate
-        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
-    }
-
-    *s = hsum_float_8(acc);
-#elif defined(__SSSE3__)
-    // set constants
-    const __m128i lowMask = _mm_set1_epi8(0xF);
-    const __m128i off = _mm_set1_epi8(8);
-
-    // Initialize accumulator with zeros
-    __m128 acc_0 = _mm_setzero_ps();
-    __m128 acc_1 = _mm_setzero_ps();
-    __m128 acc_2 = _mm_setzero_ps();
-    __m128 acc_3 = _mm_setzero_ps();
-
-    // First round without accumulation
-    {
-        _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
-        _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
-
-        // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
-
-        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
-
-        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
-        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
-        bx_0 = _mm_sub_epi8(bx_0, off);
-        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
-
-        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
-        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
-        bx_1 = _mm_sub_epi8(bx_1, off);
-        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
-
-        _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
-        _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
-
-        // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
-
-        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
-
-        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
-        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
-        bx_2 = _mm_sub_epi8(bx_2, off);
-        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
-
-        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
-        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
-        bx_3 = _mm_sub_epi8(bx_3, off);
-        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
-
-        // Convert int32_t to float
-        __m128 p0 = _mm_cvtepi32_ps(i32_0);
-        __m128 p1 = _mm_cvtepi32_ps(i32_1);
-        __m128 p2 = _mm_cvtepi32_ps(i32_2);
-        __m128 p3 = _mm_cvtepi32_ps(i32_3);
-
-        // Apply the scale
-        acc_0 = _mm_mul_ps( d_0_1, p0 );
-        acc_1 = _mm_mul_ps( d_0_1, p1 );
-        acc_2 = _mm_mul_ps( d_2_3, p2 );
-        acc_3 = _mm_mul_ps( d_2_3, p3 );
-    }
-
-    // Main loop
-    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
-    for (int i = 2; i < nb; i+=2) {
-        _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
-        _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
-
-        // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
-
-        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
-
-        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
-        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
-        bx_0 = _mm_sub_epi8(bx_0, off);
-        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
-
-        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
-        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
-        bx_1 = _mm_sub_epi8(bx_1, off);
-        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
-
-        _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
-        _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
-
-        // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
-
-        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
-
-        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
-        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
-        bx_2 = _mm_sub_epi8(bx_2, off);
-        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
-
-        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
-        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
-        bx_3 = _mm_sub_epi8(bx_3, off);
-        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
-
-        // Convert int32_t to float
-        __m128 p0 = _mm_cvtepi32_ps(i32_0);
-        __m128 p1 = _mm_cvtepi32_ps(i32_1);
-        __m128 p2 = _mm_cvtepi32_ps(i32_2);
-        __m128 p3 = _mm_cvtepi32_ps(i32_3);
-
-        // Apply the scale
-        __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
-        __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
-        __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
-        __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
-
-        // Acummulate
-        acc_0 = _mm_add_ps(p0_d, acc_0);
-        acc_1 = _mm_add_ps(p1_d, acc_1);
-        acc_2 = _mm_add_ps(p2_d, acc_2);
-        acc_3 = _mm_add_ps(p3_d, acc_3);
-    }
-
-    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
-#elif defined(__riscv_v_intrinsic)
-    float sumf = 0.0;
-
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
-    for (int i = 0; i < nb; i++) {
-        // load elements
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
-
-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
-
-        // mask and store lower part of x, and then upper part
-        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
-        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
-        // subtract offset
-        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
-        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
-
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
-        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
-
-        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
-    }
+//
+// simd mappings
+//
 
-    *s = sumf;
-#else
-    // scalar
-    float sumf = 0.0;
+// we define a common set of C macros which map to specific intrinsics based on the current architecture
+// we then implement the fundamental computation operations below using only these macros
+// adding support for new architectures requires to define the corresponding SIMD macros
+//
+// GGML_F32_STEP / GGML_F16_STEP
+//   number of elements to process in a single step
+//
+// GGML_F32_EPR / GGML_F16_EPR
+//   number of elements to fit in a single register
+//
 
-    for (int i = 0; i < nb; i++) {
-        int sumi = 0;
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
 
-        for (int j = 0; j < qk/2; ++j) {
-            const int v0 = (x[i].qs[j] & 0x0F) - 8;
-            const int v1 = (x[i].qs[j] >>   4) - 8;
+#define GGML_SIMD
 
-            sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
-        }
+// F32 NEON
 
-        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
-    }
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
 
-    *s = sumf;
-#endif
+#define GGML_F32x4              float32x4_t
+#define GGML_F32x4_ZERO         vdupq_n_f32(0.0f)
+#define GGML_F32x4_SET1(x)      vdupq_n_f32(x)
+#define GGML_F32x4_LOAD         vld1q_f32
+#define GGML_F32x4_STORE        vst1q_f32
+#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
+#define GGML_F32x4_ADD          vaddq_f32
+#define GGML_F32x4_MUL          vmulq_f32
+#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
+#define GGML_F32x4_REDUCE(res, x)              \
+{                                              \
+    int offset = GGML_F32_ARR >> 1;            \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f32(x[i], x[offset+i]);   \
+    }                                          \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f32(x[i], x[offset+i]);   \
+    }                                          \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f32(x[i], x[offset+i]);   \
+    }                                          \
+    res = GGML_F32x4_REDUCE_ONE(x[0]);         \
 }
 
-static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int qk = QK8_1;
-    const int nb = n / qk;
-
-    assert(n % qk == 0);
-
-    const block_q4_1 * restrict x = vx;
-    const block_q8_1 * restrict y = vy;
-
-    // TODO: add WASM SIMD
-#if defined(__ARM_NEON)
-    float32x4_t sumv0 = vdupq_n_f32(0.0f);
-    float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
-    float summs = 0;
-
-    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_1 * restrict x0 = &x[i + 0];
-        const block_q4_1 * restrict x1 = &x[i + 1];
-        const block_q8_1 * restrict y0 = &y[i + 0];
-        const block_q8_1 * restrict y1 = &y[i + 1];
-
-        summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
-
-        const uint8x16_t m4b = vdupq_n_u8(0x0F);
-
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
 
-        // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
-        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
-        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+// F16 NEON
 
-        // load y
-        const int8x16_t v1_0l = vld1q_s8(y0->qs);
-        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-        const int8x16_t v1_1l = vld1q_s8(y1->qs);
-        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    #define GGML_F16_STEP 32
+    #define GGML_F16_EPR  8
 
-#if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+    #define GGML_F16x8              float16x8_t
+    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)
+    #define GGML_F16x8_SET1(x)      vdupq_n_f16(x)
+    #define GGML_F16x8_LOAD         vld1q_f16
+    #define GGML_F16x8_STORE        vst1q_f16
+    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
+    #define GGML_F16x8_ADD          vaddq_f16
+    #define GGML_F16x8_MUL          vmulq_f16
+    #define GGML_F16x8_REDUCE(res, x)                             \
+    do {                                                          \
+        int offset = GGML_F16_ARR >> 1;                           \
+        for (int i = 0; i < offset; ++i) {                        \
+            x[i] = vaddq_f16(x[i], x[offset+i]);                  \
+        }                                                         \
+        offset >>= 1;                                             \
+        for (int i = 0; i < offset; ++i) {                        \
+            x[i] = vaddq_f16(x[i], x[offset+i]);                  \
+        }                                                         \
+        offset >>= 1;                                             \
+        for (int i = 0; i < offset; ++i) {                        \
+            x[i] = vaddq_f16(x[i], x[offset+i]);                  \
+        }                                                         \
+        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
+        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
+        res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
+    } while (0)
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
+    #define GGML_F16_VEC                GGML_F16x8
+    #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
+    #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
+    #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
+    #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#endif
-    }
-
-    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
-#elif defined(__AVX2__) || defined(__AVX__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-
-    float summs = 0;
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        const float d0 = GGML_FP16_TO_FP32(x[i].d);
-        const float d1 = y[i].d;
-
-        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
-
-        const __m256 d0v = _mm256_set1_ps( d0 );
-        const __m256 d1v = _mm256_set1_ps( d1 );
-
-        // Compute combined scales
-        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+    // if FP16 vector arithmetic is not supported, we use FP32 instead
+    // and take advantage of the vcvt_ functions to convert to/from FP16
 
-        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        const __m256i bx = bytes_from_nibbles_32(x[i].qs);
-        const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
+    #define GGML_F16_STEP 16
+    #define GGML_F16_EPR  4
 
-        const __m256 xy = mul_sum_us8_pairs_float(bx, by);
+    #define GGML_F32Cx4              float32x4_t
+    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)
+    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)
+    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16(x))
+    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))
+    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
+    #define GGML_F32Cx4_ADD          vaddq_f32
+    #define GGML_F32Cx4_MUL          vmulq_f32
+    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE
 
-        // Accumulate d0*d1*x*y
-#if defined(__AVX2__)
-        acc = _mm256_fmadd_ps( d0d1, xy, acc );
-#else
-        acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+    #define GGML_F16_VEC                GGML_F32Cx4
+    #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+    #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
+    #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
+    #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
 #endif
-    }
-
-    *s = hsum_float_8(acc) + summs;
-#elif defined(__riscv_v_intrinsic)
-    float sumf = 0.0;
-
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
-    for (int i = 0; i < nb; i++) {
-        // load elements
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
-
-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
-
-        // mask and store lower part of x, and then upper part
-        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
-        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
-        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
-
-        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
-    }
-
-    *s = sumf;
-#else
-    // scalar
-    float sumf = 0.0;
-
-    for (int i = 0; i < nb; i++) {
-        int sumi = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int v0 = (x[i].qs[j] & 0x0F);
-            const int v1 = (x[i].qs[j] >>   4);
-
-            sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
-        }
-
-        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
-    }
 
-    *s = sumf;
-#endif
-}
+#elif defined(__AVX__)
 
-static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
+#define GGML_SIMD
 
-    assert(n % qk == 0);
-    assert(qk == QK5_0);
+// F32 AVX
 
-    const block_q5_0 * restrict x = vx;
-    const block_q8_0 * restrict y = vy;
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
 
-#if defined(__ARM_NEON)
-    float32x4_t sumv0 = vdupq_n_f32(0.0f);
-    float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
-    uint32_t qh0;
-    uint32_t qh1;
-
-    uint64_t tmp0[4];
-    uint64_t tmp1[4];
-
-    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
-    for (int i = 0; i < nb; i += 2) {
-        const block_q5_0 * restrict x0 = &x[i];
-        const block_q5_0 * restrict x1 = &x[i + 1];
-        const block_q8_0 * restrict y0 = &y[i];
-        const block_q8_0 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0x0F);
-
-        // extract the 5th bit via lookup table ((!b) << 4)
-        memcpy(&qh0, x0->qh, sizeof(qh0));
-        memcpy(&qh1, x1->qh, sizeof(qh1));
-
-        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
-        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
-        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
-        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];
-
-        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
-        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
-        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
-        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];
-
-        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
-        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
-        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
-        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
-
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-
-        // 4-bit -> 8-bit
-        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
-        int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
-        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-
-        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
-        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
-        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
-        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
-        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
-
-        // load y
-        const int8x16_t v1_0l = vld1q_s8(y0->qs);
-        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-        const int8x16_t v1_1l = vld1q_s8(y1->qs);
-        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    _mm256_setzero_ps()
+#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
+#define GGML_F32x8_LOAD    _mm256_loadu_ps
+#define GGML_F32x8_STORE   _mm256_storeu_ps
+#if defined(__FMA__)
+    #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
 #endif
-    }
-
-    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__wasm_simd128__)
-    v128_t sumv = wasm_f32x4_splat(0.0f);
-
-    uint32_t qh;
-    uint64_t tmp[4];
-
-    // TODO: check if unrolling this is better
-    for (int i = 0; i < nb; ++i) {
-        const block_q5_0 * restrict x0 = &x[i];
-        const block_q8_0 * restrict y0 = &y[i];
-
-        const v128_t m4b  = wasm_i8x16_splat(0x0F);
-
-        // extract the 5th bit
-        memcpy(&qh, x0->qh, sizeof(qh));
-
-        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_1[(qh >> 24)       ];
-
-        const v128_t qhl = wasm_v128_load(tmp + 0);
-        const v128_t qhh = wasm_v128_load(tmp + 2);
-
-        const v128_t v0 = wasm_v128_load(x0->qs);
-
-        // 4-bit -> 8-bit
-        const v128_t v0l = wasm_v128_and (v0, m4b);
-        const v128_t v0h = wasm_u8x16_shr(v0, 4);
-
-        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
-        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
-        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
-
-        // load y
-        const v128_t v1l = wasm_v128_load(y0->qs);
-        const v128_t v1h = wasm_v128_load(y0->qs + 16);
-
-        // int8x16 -> int16x8
-        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
-        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
-        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
-        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
-
-        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
-        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
-        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
-        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
-
-        // dot product
-        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
-                        wasm_i32x4_add(
-                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
-                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
-                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
-                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
-                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
-    }
-
-    *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
-         wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
-#elif defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-
-    // Main loop
-    for (int i = 0; i < nb; i++) {
-        /* Compute combined scale for the block */
-        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
-
-        __m256i bx = bytes_from_nibbles_32(x[i].qs);
-        __m256i bxhi = bytes_from_bits_32(x[i].qh);
-        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
-        bx = _mm256_or_si256(bx, bxhi);
-
-        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
-
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+#define GGML_F32x8_ADD     _mm256_add_ps
+#define GGML_F32x8_MUL     _mm256_mul_ps
+#define GGML_F32x8_REDUCE(res, x)                                 \
+do {                                                              \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \
+                                 _mm256_extractf128_ps(x[0], 1)); \
+    const __m128 t1 = _mm_hadd_ps(t0, t0);                        \
+    res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));                     \
+} while (0)
+// TODO: is this optimal ?
 
-        /* Multiply q with scale and accumulate */
-        acc = _mm256_fmadd_ps(d, q, acc);
-    }
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
 
-    *s = hsum_float_8(acc);
-#elif defined(__AVX__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-    __m128i mask = _mm_set1_epi8((char)0xF0);
+// F16 AVX
 
-    // Main loop
-    for (int i = 0; i < nb; i++) {
-        /* Compute combined scale for the block */
-        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
 
-        __m256i bx = bytes_from_nibbles_32(x[i].qs);
-        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
-        __m128i bxhil = _mm256_castsi256_si128(bxhi);
-        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
-        bxhil = _mm_andnot_si128(bxhil, mask);
-        bxhih = _mm_andnot_si128(bxhih, mask);
-        __m128i bxl = _mm256_castsi256_si128(bx);
-        __m128i bxh = _mm256_extractf128_si256(bx, 1);
-        bxl = _mm_or_si128(bxl, bxhil);
-        bxh = _mm_or_si128(bxh, bxhih);
-        bx = MM256_SET_M128I(bxh, bxl);
+// F16 arithmetic is not supported by AVX, so we use F32 instead
 
-        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+#define GGML_F32Cx8             __m256
+#define GGML_F32Cx8_ZERO        _mm256_setzero_ps()
+#define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)
 
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+#if defined(__F16C__)
+// the  _mm256_cvt intrinsics require F16C
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
+#else
+static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
+    float tmp[8];
 
-        /* Multiply q with scale and accumulate */
-        acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
     }
 
-    *s = hsum_float_8(acc);
-#elif defined(__riscv_v_intrinsic)
-    float sumf = 0.0;
+    return _mm256_loadu_ps(tmp);
+}
+static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+    float arr[8];
 
-    uint32_t qh;
+    _mm256_storeu_ps(arr, y);
 
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+    for (int i = 0; i < 8; i++)
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+}
+#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
+#endif
 
-    // These tempory registers are for masking and shift operations
-    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
-    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         _mm256_add_ps
+#define GGML_F32Cx8_MUL         _mm256_mul_ps
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
 
-    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
-    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
 
-    for (int i = 0; i < nb; i++) {
-        memcpy(&qh, x[i].qh, sizeof(uint32_t));
+#elif defined(__POWER9_VECTOR__)
 
-        // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
-        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
-        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+#define GGML_SIMD
 
-        // ((qh & (1u << (j + 16))) >> (j + 12));
-        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
-        vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
+// F32 POWER9
 
-        // narrowing
-        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
-        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
 
-        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
-        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+#define GGML_F32x4              vector float
+#define GGML_F32x4_ZERO         0.0f
+#define GGML_F32x4_SET1         vec_splats
+#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
+#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
+#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
+#define GGML_F32x4_ADD          vec_add
+#define GGML_F32x4_MUL          vec_mul
+#define GGML_F32x4_REDUCE(res, x)              \
+{                                              \
+    int offset = GGML_F32_ARR >> 1;            \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    res = vec_extract(x[0], 0) +               \
+          vec_extract(x[0], 1) +               \
+          vec_extract(x[0], 2) +               \
+          vec_extract(x[0], 3);                \
+}
 
-        // load
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
 
-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+// F16 POWER9
+#define GGML_F16_STEP       GGML_F32_STEP
+#define GGML_F16_EPR        GGML_F32_EPR
+#define GGML_F16_VEC        GGML_F32x4
+#define GGML_F16_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F16_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F16_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
+// Use vec_xl, not vec_ld, in case the load address is not aligned.
+#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?                   \
+  vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
+  vec_extract_fp32_from_shortl(vec_xl(0, p))
+#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
+#define GGML_F16_VEC_STORE(p, r, i)                             \
+  if (i & 0x1)                                                  \
+    vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)],  \
+                                   r[i - GGML_ENDIAN_BYTE(0)]), \
+            0, p - GGML_F16_EPR)
 
-        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+#elif defined(__wasm_simd128__)
 
-        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
-        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+#define GGML_SIMD
 
-        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+// F32 WASM
 
-        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
-        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
 
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+#define GGML_F32x4              v128_t
+#define GGML_F32x4_ZERO         wasm_f32x4_splat(0.0f)
+#define GGML_F32x4_SET1(x)      wasm_f32x4_splat(x)
+#define GGML_F32x4_LOAD         wasm_v128_load
+#define GGML_F32x4_STORE        wasm_v128_store
+#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
+#define GGML_F32x4_ADD          wasm_f32x4_add
+#define GGML_F32x4_MUL          wasm_f32x4_mul
+#define GGML_F32x4_REDUCE(res, x)                  \
+{                                                  \
+    int offset = GGML_F32_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    res = wasm_f32x4_extract_lane(x[0], 0) +       \
+          wasm_f32x4_extract_lane(x[0], 1) +       \
+          wasm_f32x4_extract_lane(x[0], 2) +       \
+          wasm_f32x4_extract_lane(x[0], 3);        \
+}
 
-        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
 
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+// F16 WASM
 
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+#define GGML_F16_STEP 16
+#define GGML_F16_EPR  4
 
-        sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
-    }
+inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
+    float tmp[4];
 
-    *s = sumf;
-#else
-    // scalar
-    float sumf = 0.0;
+    tmp[0] = GGML_FP16_TO_FP32(p[0]);
+    tmp[1] = GGML_FP16_TO_FP32(p[1]);
+    tmp[2] = GGML_FP16_TO_FP32(p[2]);
+    tmp[3] = GGML_FP16_TO_FP32(p[3]);
 
-    for (int i = 0; i < nb; i++) {
-        uint32_t qh;
-        memcpy(&qh, x[i].qh, sizeof(qh));
+    return wasm_v128_load(tmp);
+}
 
-        int sumi = 0;
+inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
+    float tmp[4];
 
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
+    wasm_v128_store(tmp, x);
 
-            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
-            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;
+    p[0] = GGML_FP32_TO_FP16(tmp[0]);
+    p[1] = GGML_FP32_TO_FP16(tmp[1]);
+    p[2] = GGML_FP32_TO_FP16(tmp[2]);
+    p[3] = GGML_FP32_TO_FP16(tmp[3]);
+}
 
-            sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
-        }
+#define GGML_F16x4             v128_t
+#define GGML_F16x4_ZERO        wasm_f32x4_splat(0.0f)
+#define GGML_F16x4_SET1(x)     wasm_f32x4_splat(x)
+#define GGML_F16x4_LOAD(x)     __wasm_f16x4_load(x)
+#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
+#define GGML_F16x4_FMA         GGML_F32x4_FMA
+#define GGML_F16x4_ADD         wasm_f32x4_add
+#define GGML_F16x4_MUL         wasm_f32x4_mul
+#define GGML_F16x4_REDUCE(res, x)                  \
+{                                                  \
+    int offset = GGML_F16_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    res = wasm_f32x4_extract_lane(x[0], 0) +       \
+          wasm_f32x4_extract_lane(x[0], 1) +       \
+          wasm_f32x4_extract_lane(x[0], 2) +       \
+          wasm_f32x4_extract_lane(x[0], 3);        \
+}
 
-        sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
-    }
+#define GGML_F16_VEC                GGML_F16x4
+#define GGML_F16_VEC_ZERO           GGML_F16x4_ZERO
+#define GGML_F16_VEC_SET1           GGML_F16x4_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F16x4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F16x4_FMA
+#define GGML_F16_VEC_ADD            GGML_F16x4_ADD
+#define GGML_F16_VEC_MUL            GGML_F16x4_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE
 
-    *s = sumf;
-#endif
-}
+#elif defined(__SSE3__)
 
-static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int qk = QK8_1;
-    const int nb = n / qk;
+#define GGML_SIMD
 
-    assert(n % qk == 0);
-    assert(qk == QK5_1);
+// F32 SSE
 
-    const block_q5_1 * restrict x = vx;
-    const block_q8_1 * restrict y = vy;
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
 
-#if defined(__ARM_NEON)
-    float32x4_t sumv0 = vdupq_n_f32(0.0f);
-    float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
-    float summs0 = 0.0f;
-    float summs1 = 0.0f;
-
-    uint32_t qh0;
-    uint32_t qh1;
-
-    uint64_t tmp0[4];
-    uint64_t tmp1[4];
-
-    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
-    for (int i = 0; i < nb; i += 2) {
-        const block_q5_1 * restrict x0 = &x[i];
-        const block_q5_1 * restrict x1 = &x[i + 1];
-        const block_q8_1 * restrict y0 = &y[i];
-        const block_q8_1 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0x0F);
-
-        summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
-        summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
-
-        // extract the 5th bit via lookup table ((b) << 4)
-        memcpy(&qh0, x0->qh, sizeof(qh0));
-        memcpy(&qh1, x1->qh, sizeof(qh1));
-
-        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
-        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
-        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
-        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];
-
-        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
-        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
-        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
-        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];
-
-        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
-        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
-        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
-        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
-
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-
-        // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
-        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
-        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-
-        // add high bit
-        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
-        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
-        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
-        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
-
-        // load y
-        const int8x16_t v1_0l = vld1q_s8(y0->qs);
-        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-        const int8x16_t v1_1l = vld1q_s8(y1->qs);
-        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    _mm_setzero_ps()
+#define GGML_F32x4_SET1(x) _mm_set1_ps(x)
+#define GGML_F32x4_LOAD    _mm_loadu_ps
+#define GGML_F32x4_STORE   _mm_storeu_ps
+#if defined(__FMA__)
+    // TODO: Does this work?
+    #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
+    #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
 #endif
-    }
-
-    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
-#elif defined(__wasm_simd128__)
-    v128_t sumv = wasm_f32x4_splat(0.0f);
-
-    float summs = 0.0f;
-
-    uint32_t qh;
-    uint64_t tmp[4];
-
-    // TODO: check if unrolling this is better
-    for (int i = 0; i < nb; ++i) {
-        const block_q5_1 * restrict x0 = &x[i];
-        const block_q8_1 * restrict y0 = &y[i];
-
-        summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
-
-        const v128_t m4b = wasm_i8x16_splat(0x0F);
-
-        // extract the 5th bit
-        memcpy(&qh, x0->qh, sizeof(qh));
-
-        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_0[(qh >> 24)       ];
+#define GGML_F32x4_ADD     _mm_add_ps
+#define GGML_F32x4_MUL     _mm_mul_ps
+#define GGML_F32x4_REDUCE(res, x)                                 \
+{                                                                 \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    const __m128 t0 = _mm_hadd_ps(x[0], x[0]);                    \
+    res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0));                     \
+}
+// TODO: is this optimal ?
 
-        const v128_t qhl = wasm_v128_load(tmp + 0);
-        const v128_t qhh = wasm_v128_load(tmp + 2);
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
 
-        const v128_t v0 = wasm_v128_load(x0->qs);
+// F16 SSE
 
-        // 4-bit -> 8-bit
-        const v128_t v0l = wasm_v128_and (v0, m4b);
-        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
 
-        // add high bit
-        const v128_t v0lf = wasm_v128_or(v0l, qhl);
-        const v128_t v0hf = wasm_v128_or(v0h, qhh);
+static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
+    float tmp[4];
 
-        // load y
-        const v128_t v1l = wasm_v128_load(y0->qs);
-        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
 
-        // int8x16 -> int16x8
-        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
-        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
-        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
-        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+    return _mm_loadu_ps(tmp);
+}
 
-        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
-        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
-        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
-        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
+    float arr[4];
 
-        // dot product
-        sumv = wasm_f32x4_add(sumv,
-                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
-                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
-                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
-                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
-                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
-                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
-    }
+    _mm_storeu_ps(arr, y);
 
-    *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
-         wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
-#elif defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
 
-    float summs = 0.0f;
+#define GGML_F32Cx4             __m128
+#define GGML_F32Cx4_ZERO        _mm_setzero_ps()
+#define GGML_F32Cx4_SET1(x)     _mm_set1_ps(x)
+#define GGML_F32Cx4_LOAD(x)     __sse_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD         _mm_add_ps
+#define GGML_F32Cx4_MUL         _mm_mul_ps
+#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
 
-    // Main loop
-    for (int i = 0; i < nb; i++) {
-        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+#define GGML_F16_VEC                 GGML_F32Cx4
+#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
 
-        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+#endif
 
-        __m256i bx = bytes_from_nibbles_32(x[i].qs);
-        __m256i bxhi = bytes_from_bits_32(x[i].qh);
-        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
-        bx = _mm256_or_si256(bx, bxhi);
+// GGML_F32_ARR / GGML_F16_ARR
+//   number of registers to use per step
+#ifdef GGML_SIMD
+#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
+#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
+#endif
 
-        const __m256 dy = _mm256_set1_ps(y[i].d);
-        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+//
+// fundamental operations
+//
 
-        const __m256 q = mul_sum_us8_pairs_float(bx, by);
+inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
-        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
-    }
+inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
-    *s = hsum_float_8(acc) + summs;
-#elif defined(__AVX__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-    __m128i mask = _mm_set1_epi8(0x10);
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
-    float summs = 0.0f;
+inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
-    // Main loop
-    for (int i = 0; i < nb; i++) {
-        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
+inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    }
+inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
+inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
+inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
+inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
+inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
+inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
+inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
+inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
 
-        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
+#ifdef GGML_SIMD
+    float sumf = 0.0f;
+    const int np = (n & ~(GGML_F32_STEP - 1));
 
-        __m256i bx = bytes_from_nibbles_32(x[i].qs);
-        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
-        __m128i bxhil = _mm256_castsi256_si128(bxhi);
-        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
-        bxhil = _mm_and_si128(bxhil, mask);
-        bxhih = _mm_and_si128(bxhih, mask);
-        __m128i bxl = _mm256_castsi256_si128(bx);
-        __m128i bxh = _mm256_extractf128_si256(bx, 1);
-        bxl = _mm_or_si128(bxl, bxhil);
-        bxh = _mm_or_si128(bxh, bxhih);
-        bx = MM256_SET_M128I(bxh, bxl);
+    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
 
-        const __m256 dy = _mm256_set1_ps(y[i].d);
-        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
 
-        const __m256 q = mul_sum_us8_pairs_float(bx, by);
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
 
-        acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
+            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
     }
 
-    *s = hsum_float_8(acc) + summs;
-#elif defined(__riscv_v_intrinsic)
-    float sumf = 0.0;
-
-    uint32_t qh;
-
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
-    // temporary registers for shift operations
-    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
-    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
-
-    for (int i = 0; i < nb; i++) {
-        memcpy(&qh, x[i].qh, sizeof(uint32_t));
-
-        // load qh
-        vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
-
-        // ((qh >> (j +  0)) << 4) & 0x10;
-        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
-        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
-        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
-
-        // ((qh >> (j + 12))     ) & 0x10;
-        vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
-        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
-
-        // narrowing
-        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
-        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
-
-        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
-        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
-
-        // load
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
-
-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
-
-        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
-        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
-        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-
-        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
-        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+    // reduce sum0..sum3 to sum0
+    GGML_F32_VEC_REDUCE(sumf, sum);
 
-        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += x[i]*y[i];
     }
-
-    *s = sumf;
 #else
     // scalar
-    float sumf = 0.0;
-
-    for (int i = 0; i < nb; i++) {
-        uint32_t qh;
-        memcpy(&qh, x[i].qh, sizeof(qh));
-
-        int sumi = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-            const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
-            const int32_t x1 = (x[i].qs[j] >>  4) | xh_1;
-
-            sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
-        }
-
-        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    ggml_float sumf = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sumf += (ggml_float)(x[i]*y[i]);
     }
+#endif
 
     *s = sumf;
-#endif
 }
 
-static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-
-    assert(n % qk == 0);
-
-    const block_q8_0 * restrict x = vx;
-    const block_q8_0 * restrict y = vy;
-
-#if defined(__ARM_NEON)
-    float32x4_t sumv0 = vdupq_n_f32(0.0f);
-    float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
-    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
-    for (int i = 0; i < nb; i += 2) {
-        const block_q8_0 * restrict x0 = &x[i + 0];
-        const block_q8_0 * restrict x1 = &x[i + 1];
-        const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
-
-        const int8x16_t x0_0 = vld1q_s8(x0->qs);
-        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
-        const int8x16_t x1_0 = vld1q_s8(x1->qs);
-        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
-
-        // load y
-        const int8x16_t y0_0 = vld1q_s8(y0->qs);
-        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
-        const int8x16_t y1_0 = vld1q_s8(y1->qs);
-        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
-                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
-                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+    ggml_float sumf = 0.0;
 
-#else
-        const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
-        const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
-        const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
-        const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-
-        const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
-        const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
-        const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
-        const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
-
-        const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
-        const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
-        const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
-        const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#endif
-    }
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
 
-    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__AVX2__) || defined(__AVX__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
+    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
 
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        // Compute combined scale for the block
-        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
-        __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
-        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
 
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
 
-        // Multiply q with scale and accumulate
-#if defined(__AVX2__)
-        acc = _mm256_fmadd_ps( d, q, acc );
-#else
-        acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
-#endif
+            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
     }
 
-    *s = hsum_float_8(acc);
-#elif defined(__riscv_v_intrinsic)
-    float sumf = 0.0;
-    size_t vl = __riscv_vsetvl_e8m1(qk);
-
-    for (int i = 0; i < nb; i++) {
-        // load elements
-        vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
-        vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
-
-        vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
-
-        vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
-        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
-
-        int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
+    // reduce sum0..sum3 to sum0
+    GGML_F16_VEC_REDUCE(sumf, sum);
 
-        sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
     }
-
-    *s = sumf;
 #else
-    // scalar
-    float sumf = 0.0;
-
-    for (int i = 0; i < nb; i++) {
-        int sumi = 0;
-
-        for (int j = 0; j < qk; j++) {
-            sumi += x[i].qs[j]*y[i].qs[j];
-        }
-
-        sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
+    for (int i = 0; i < n; ++i) {
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
     }
+#endif
 
     *s = sumf;
-#endif
 }
 
 // compute GGML_VEC_DOT_UNROLL dot products at once
@@ -3886,7 +1408,7 @@ inline static float ggml_gelu_f32(float x) {
 inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
     const uint16_t * i16 = (const uint16_t *) x;
     for (int i = 0; i < n; ++i) {
-        y[i] = table_gelu_f16[i16[i]];
+        y[i] = ggml_table_gelu_f16[i16[i]];
     }
 }
 
@@ -3896,7 +1418,7 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
     for (int i = 0; i < n; ++i) {
         ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
         memcpy(&t, &fp16, sizeof(uint16_t));
-        y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]);
+        y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
     }
 }
 #else
@@ -3914,7 +1436,7 @@ inline static float ggml_gelu_quick_f32(float x) {
 //inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
 //    const uint16_t * i16 = (const uint16_t *) x;
 //    for (int i = 0; i < n; ++i) {
-//        y[i] = table_gelu_quick_f16[i16[i]];
+//        y[i] = ggml_table_gelu_quick_f16[i16[i]];
 //    }
 //}
 
@@ -3924,7 +1446,7 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float *
     for (int i = 0; i < n; ++i) {
         ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
         memcpy(&t, &fp16, sizeof(uint16_t));
-        y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
+        y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
     }
 }
 #else
@@ -3943,7 +1465,7 @@ inline static float ggml_silu_f32(float x) {
 //inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
 //    const uint16_t * i16 = (const uint16_t *) x;
 //    for (int i = 0; i < n; ++i) {
-//        y[i] = table_silu_f16[i16[i]];
+//        y[i] = ggml_table_silu_f16[i16[i]];
 //    }
 //}
 
@@ -3953,7 +1475,7 @@ inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     for (int i = 0; i < n; ++i) {
         ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
         memcpy(&t, &fp16, sizeof(uint16_t));
-        y[i] = GGML_FP16_TO_FP32(table_silu_f16[t]);
+        y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]);
     }
 }
 #else
@@ -4669,11 +2191,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             for (int i = 0; i < (1 << 16); ++i) {
                 uint16_t ui = i;
                 memcpy(&ii, &ui, sizeof(ii));
-                const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
-                table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
-                table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
-                table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
-                table_exp_f16[i]  = GGML_FP32_TO_FP16(expf(f));
+                const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
+                ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+                ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
+                ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
+                ggml_table_exp_f16[i]  = GGML_FP32_TO_FP16(expf(f));
             }
 
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -5676,7 +3198,7 @@ static struct ggml_tensor * ggml_add_cast_impl(
     // TODO: support less-strict constraint
     //       GGML_ASSERT(ggml_can_repeat(b, a));
     GGML_ASSERT(ggml_can_repeat_rows(b, a));
-    GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+    GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
 
     bool is_node = false;
 
@@ -7376,8 +4898,13 @@ static struct ggml_tensor * ggml_rope_impl(
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
+        int                   n_orig_ctx,
         float                 freq_base,
         float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow,
         float                 xpos_base,
         bool                  xpos_down,
         bool                  inplace) {
@@ -7393,11 +4920,15 @@ static struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
-    memcpy(params + 4, &freq_base,  sizeof(float));
-    memcpy(params + 5, &freq_scale, sizeof(float));
-    memcpy(params + 6, &xpos_base,  sizeof(float));
-    memcpy(params + 7, &xpos_down,  sizeof(bool));
+    int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+    memcpy(params +  5, &freq_base,    sizeof(float));
+    memcpy(params +  6, &freq_scale,   sizeof(float));
+    memcpy(params +  7, &ext_factor,   sizeof(float));
+    memcpy(params +  8, &attn_factor,  sizeof(float));
+    memcpy(params +  9, &beta_fast,    sizeof(float));
+    memcpy(params + 10, &beta_slow,    sizeof(float));
+    memcpy(params + 11, &xpos_base,    sizeof(float));
+    memcpy(params + 12, &xpos_down,    sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE;
@@ -7415,7 +4946,9 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+    );
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -7425,7 +4958,9 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+    );
 }
 
 struct ggml_tensor * ggml_rope_custom(
@@ -7435,9 +4970,17 @@ struct ggml_tensor * ggml_rope_custom(
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
+        int                   n_orig_ctx,
         float                 freq_base,
-        float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+    );
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
@@ -7447,9 +4990,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
+        int                   n_orig_ctx,
         float                 freq_base,
-        float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+    );
 }
 
 struct ggml_tensor * ggml_rope_xpos_inplace(
@@ -7459,7 +5010,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
         int                   n_dims,
         float                 base,
         bool                  down) {
-    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
+    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
 }
 
 // ggml_rope_back
@@ -7959,6 +5510,7 @@ struct ggml_tensor * ggml_pool_2d(
         GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
+
     const int64_t ne[3] = {
         ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
         ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
@@ -9457,9 +7009,15 @@ static void ggml_compute_forward_add_f16_f32(
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
 
-    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    if (dst->type == GGML_TYPE_F32) {
+        GGML_ASSERT( nb0 == sizeof(float));
+    }
+    else {
+        GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+        GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    }
+
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 
     // rows per thread
@@ -9470,18 +7028,35 @@ static void ggml_compute_forward_add_f16_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (nb10 == sizeof(float)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-            for (int i = 0; i < ne0; i++) {
-                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+        if (dst->type == GGML_TYPE_F16) {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+                }
+            }
+        } else {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                float *       dst_ptr  = (float *)       ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
+                }
             }
         }
     }
@@ -13085,7 +10660,7 @@ static void ggml_compute_forward_soft_max_f32(
                 // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
                 ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
-                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
                 sum += (ggml_float)val;
                 dp[i] = val;
             }
@@ -13450,6 +11025,45 @@ static void ggml_compute_forward_clamp(
 
 // ggml_compute_forward_rope
 
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+    return 1 - MIN(1, MAX(0, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+}
+
+void ggml_rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = MAX(0,         floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
 static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -13459,21 +11073,26 @@ static void ggml_compute_forward_rope_f32(
         return;
     }
 
-    float freq_base;
-    float freq_scale;
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
     // these two only relevant for xPos RoPE:
     float xpos_base;
     bool  xpos_down;
 
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
-    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&xpos_down,  (int32_t *) dst->op_params + 7, sizeof(bool));
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&xpos_base,   (int32_t *) dst->op_params + 11, sizeof(float));
+    memcpy(&xpos_down,   (int32_t *) dst->op_params + 12, sizeof(bool));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -13501,6 +11120,9 @@ static void ggml_compute_forward_rope_f32(
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float inv_ndims = -1.f/n_dims;
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -13514,18 +11136,18 @@ static void ggml_compute_forward_rope_f32(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = freq_scale * (float)p;
+                float theta_base = (float)p;
 
                 if (is_glm) {
-                    theta = MIN(p, n_ctx - 2);
+                    theta_base = MIN(p, n_ctx - 2);
                     float block_theta = MAX(p - (n_ctx - 2), 0);
                     for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
                         const float cos_block_theta = cosf(block_theta);
                         const float sin_block_theta = sinf(block_theta);
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
                         block_theta *= theta_scale;
 
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -13543,13 +11165,16 @@ static void ggml_compute_forward_rope_f32(
                     }
                 } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        float cos_theta, sin_theta;
+                        rope_yarn(
+                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
+                        );
+
                         // zeta scaling for xPos only:
                         float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
                         if (xpos_down) zeta = 1.0f / zeta;
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -13563,12 +11188,19 @@ static void ggml_compute_forward_rope_f32(
                 } else {
                     // TODO: this might be wrong for ne0 != n_dims - need double check
                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    theta_base *= freq_scale;
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            // simplified from `(ib * n_dims + ic) * inv_ndims`
+                            float cur_rot = inv_ndims * ic - ib;
+
+                            float cos_theta, sin_theta;
+                            rope_yarn(
+                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                                &cos_theta, &sin_theta
+                            );
 
-                            theta *= theta_scale;
+                            theta_base *= theta_scale;
 
                             const int64_t i0 = ib*n_dims + ic/2;
 
@@ -13597,15 +11229,19 @@ static void ggml_compute_forward_rope_f16(
         return;
     }
 
-    float freq_base;
-    float freq_scale;
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
-    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -13633,6 +11269,9 @@ static void ggml_compute_forward_rope_f16(
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float inv_ndims = -1.f/n_dims;
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -13646,18 +11285,18 @@ static void ggml_compute_forward_rope_f16(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = freq_scale * (float)p;
+                float theta_base = (float)p;
 
                 if (is_glm) {
-                    theta = MIN(p, n_ctx - 2);
+                    theta_base = MIN(p, n_ctx - 2);
                     float block_theta = MAX(p - (n_ctx - 2), 0);
                     for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
                         const float cos_block_theta = cosf(block_theta);
                         const float sin_block_theta = sinf(block_theta);
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
                         block_theta *= theta_scale;
 
                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -13675,10 +11314,12 @@ static void ggml_compute_forward_rope_f16(
                     }
                 } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        float cos_theta, sin_theta;
+                        rope_yarn(
+                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
+                        );
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -13692,12 +11333,19 @@ static void ggml_compute_forward_rope_f16(
                 } else {
                     // TODO: this might be wrong for ne0 != n_dims - need double check
                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    theta_base *= freq_scale;
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            // simplified from `(ib * n_dims + ic) * inv_ndims`
+                            float cur_rot = inv_ndims * ic - ib;
+
+                            float cos_theta, sin_theta;
+                            rope_yarn(
+                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                                &cos_theta, &sin_theta
+                            );
 
-                            theta *= theta_scale;
+                            theta_base *= theta_scale;
 
                             const int64_t i0 = ib*n_dims + ic/2;
 
@@ -13805,17 +11453,18 @@ static void ggml_compute_forward_rope_back_f32(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = freq_scale * (float)p;
+                float theta_base = freq_scale * (float)p;
 
                 if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
+
                         // zeta scaling for xPos only:
                         float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
                         if (xpos_down) zeta = 1.0f / zeta;
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -13829,10 +11478,10 @@ static void ggml_compute_forward_rope_back_f32(
                 } else {
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            const float cos_theta = cosf(theta_base);
+                            const float sin_theta = sinf(theta_base);
 
-                            theta *= theta_scale;
+                            theta_base *= theta_scale;
 
                             const int64_t i0 = ib*n_dims + ic/2;
 
@@ -13905,14 +11554,14 @@ static void ggml_compute_forward_rope_back_f16(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = (float)p;
+                float theta_base = (float)p;
 
                 if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -13926,10 +11575,10 @@ static void ggml_compute_forward_rope_back_f16(
                 } else {
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            const float cos_theta = cosf(theta_base);
+                            const float sin_theta = sinf(theta_base);
 
-                            theta *= theta_scale;
+                            theta_base *= theta_scale;
 
                             const int64_t i0 = ib*n_dims + ic/2;
 
@@ -15360,7 +13009,7 @@ static void ggml_compute_forward_flash_attn_f32(
 #else
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
-                            const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                            const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
 #endif
                             sump[j] += (ggml_float)val;
                             SS[j] = val;
@@ -15562,7 +13211,7 @@ static void ggml_compute_forward_flash_attn_f16(
                         } else {
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
-                            const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                            const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
                             sump[j] += (ggml_float)val;
                             SS[j] = val;
                         }
@@ -16013,7 +13662,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
 #else
                                     ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
                                     memcpy(&scvt[j], &s, sizeof(uint16_t));
-                                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                                    const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
 #endif
                                     sump[j] += (ggml_float)val;
                                     SW[j] = val;
@@ -16767,7 +14416,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
 #else
                     ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
                     memcpy(&scvt, &s, sizeof(scvt));
-                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
 #endif
                     sum += (ggml_float)val;
                     st[i] = val;
@@ -16881,7 +14530,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
 #else
                     ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
                     memcpy(&scvt, &s, sizeof(scvt));
-                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
 #endif
                     sum += (ggml_float)val;
                     ds0[i] = val;
@@ -18091,9 +15740,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src1,
                                 n_dims,
                                 mode,
+                                0,
                                 n_ctx,
                                 freq_base,
                                 freq_scale,
+                                0.0f,
+                                1.0f,
+                                0.0f,
+                                0.0f,
                                 xpos_base,
                                 xpos_down,
                                 false),
@@ -21207,7 +18861,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
                 result = ggml_quantize_q8_0(src + start, block, n, n, hist);
             } break;
-#ifdef GGML_USE_K_QUANTS
         case GGML_TYPE_Q2_K:
             {
                 GGML_ASSERT(start % QK_K == 0);
@@ -21238,7 +18891,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 block_q6_K * block = (block_q6_K*)dst + start / QK_K;
                 result = ggml_quantize_q6_K(src + start, block, n, n, hist);
             } break;
-#endif
         case GGML_TYPE_F16:
             {
                 int elemsize = sizeof(ggml_fp16_t);
@@ -21370,8 +19022,7 @@ static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset)
     return n == size;
 }
 
-// NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
-static bool gguf_fread_str_cur(FILE * file, struct gguf_str * p, size_t * offset) {
+static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
     p->n    = 0;
     p->data = NULL;
 
@@ -21383,19 +19034,6 @@ static bool gguf_fread_str_cur(FILE * file, struct gguf_str * p, size_t * offset
     return ok;
 }
 
-static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset) {
-    p->n    = 0;
-    p->data = NULL;
-
-    bool ok = true;
-
-    uint32_t n = 0;
-    ok = ok && gguf_fread_el(file, &n,       sizeof(n), offset); p->data = calloc(n + 1, 1); p->n = n;
-    ok = ok && gguf_fread_el(file,  p->data, p->n,      offset);
-
-    return ok;
-}
-
 struct gguf_context * gguf_init_empty(void) {
     struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
 
@@ -21454,20 +19092,14 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
         ctx->data  = NULL;
 
         ok = ok && gguf_fread_el(file, &ctx->header.version,   sizeof(ctx->header.version),   &offset);
+        ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
+        ok = ok && gguf_fread_el(file, &ctx->header.n_kv,      sizeof(ctx->header.n_kv),      &offset);
 
         if (ctx->header.version == 1) {
-            // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
-            uint32_t n_tensors = 0;
-            uint32_t n_kv      = 0;
-
-            ok = ok && gguf_fread_el(file, &n_tensors, sizeof(n_tensors), &offset);
-            ok = ok && gguf_fread_el(file, &n_kv,      sizeof(n_kv),      &offset);
-
-            ctx->header.n_tensors = n_tensors;
-            ctx->header.n_kv      = n_kv;
-        } else {
-            ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
-            ok = ok && gguf_fread_el(file, &ctx->header.n_kv,      sizeof(ctx->header.n_kv),      &offset);
+            fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
+            fclose(file);
+            gguf_free(ctx);
+            return NULL;
         }
 
         if (!ok) {
@@ -21478,12 +19110,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
         }
     }
 
-    // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
-    bool (* gguf_fread_str)(FILE *, struct gguf_str *, size_t *) = gguf_fread_str_cur;
-    if (ctx->header.version == 1) {
-        gguf_fread_str = gguf_fread_str_v1;
-    }
-
     // read the kv pairs
     {
         ctx->kv = malloc(ctx->header.n_kv * sizeof(struct gguf_kv));
@@ -21514,15 +19140,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                 case GGUF_TYPE_ARRAY:
                     {
                         ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
-
-                        if (ctx->header.version == 1) {
-                            // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
-                            uint32_t n = 0;
-                            ok = ok && gguf_fread_el(file, &n, sizeof(n), &offset);
-                            kv->value.arr.n = n;
-                        } else {
-                            ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset);
-                        }
+                        ok = ok && gguf_fread_el(file, &kv->value.arr.n,    sizeof(kv->value.arr.n), &offset);
 
                         switch (kv->value.arr.type) {
                             case GGUF_TYPE_UINT8:
@@ -21581,14 +19199,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
             ok = ok && gguf_fread_str(file, &info->name,                          &offset);
             ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims),  &offset);
             for (uint32_t j = 0; j < info->n_dims; ++j) {
-                if (ctx->header.version == 1) {
-                    // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
-                    uint32_t t = 0;
-                    ok = ok && gguf_fread_el(file, &t, sizeof(t), &offset);
-                    info->ne[j] = t;
-                } else {
-                    ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
-                }
+                ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
             }
             ok = ok && gguf_fread_el (file, &info->type,   sizeof(info->type),    &offset);
             ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset),  &offset);
index 884af40548fb7912cd2e80c3c7e503bba938c06b..a2459a2867c5c087d27f2a390f96e852b7493207 100644 (file)
@@ -129,6 +129,13 @@ int main(int argc, char * argv[]) {
         ggml_type type = (ggml_type) i;
         ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
 
+        // deprecated - skip
+        if (qfns.blck_size == 0) {
+            continue;
+        }
+
+        printf("Testing %s\n", ggml_type_name((ggml_type) i));
+
         if (qfns.from_float && qfns.to_float) {
             const float total_error = total_quantization_error(qfns, test_size, test_data.data());
             const float max_quantization_error =