]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp (generalize quantize_fns + CUDA improvements)
authorGeorgi Gerganov <redacted>
Wed, 5 Jul 2023 17:14:13 +0000 (20:14 +0300)
committerGeorgi Gerganov <redacted>
Wed, 5 Jul 2023 17:14:13 +0000 (20:14 +0300)
include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-opencl.cpp
src/ggml.c

index 0af96c76b6d0ecd4951c6dc6b8b5e3959bfc3d0b..24ca8ae221c75fe453a58b4c016042921bf6dacd 100644 (file)
@@ -250,8 +250,8 @@ extern "C" {
     GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t x);
     GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
 
-    GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n);
-    GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n);
+    GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n);
+    GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n);
 
     struct ggml_object;
     struct ggml_context;
@@ -1514,26 +1514,19 @@ extern "C" {
     // Internal types and functions exposed for tests and benchmarks
     //
 
-#ifdef  __cplusplus
-    // restrict not standard in C++
-#define GGML_RESTRICT
-#else
-#define GGML_RESTRICT restrict
-#endif
-    typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-    typedef void (*quantize_row_q_t)  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-    typedef void (*vec_dot_q_t)       (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
+    typedef void (*ggml_to_float_t)(const void * x, float * y, int k);
+    typedef void (*ggml_from_float_t)(const float * x, void * y, int k);
+    typedef void (*ggml_vec_dot_t)(const int n, float * s, const void * x, const void * y);
 
     typedef struct {
-        dequantize_row_q_t dequantize_row_q;
-        quantize_row_q_t   quantize_row_q;
-        quantize_row_q_t   quantize_row_q_reference;
-        quantize_row_q_t   quantize_row_q_dot;
-        vec_dot_q_t        vec_dot_q;
-        enum ggml_type     vec_dot_type;
-    } quantize_fns_t;
-
-    quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
+        ggml_to_float_t   to_float;
+        ggml_from_float_t from_float;
+        ggml_from_float_t from_float_reference;
+        ggml_vec_dot_t    vec_dot;
+        enum ggml_type    vec_dot_type;
+    } ggml_type_traits_t;
+
+    ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type i);
 
 #ifdef  __cplusplus
 }
index 50df20edd7a7b211324e0af1c18e970bc250251d..7965ff74111f763e33d77dded69c256545cbaf70 100644 (file)
@@ -70,9 +70,11 @@ typedef void (*ggml_cuda_op_t)(
 
 // QK = number of values after dequantization
 // QR = QK / number of values before dequantization
+// QI = number of 32 bit integers before dequantization
 
 #define QK4_0 32
 #define QR4_0 2
+#define QI4_0 4
 typedef struct {
     half    d;              // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
@@ -81,6 +83,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0
 
 #define QK4_1 32
 #define QR4_1 2
+#define QI4_1 4
 typedef struct {
     half    d;              // delta
     half    m;              // min
@@ -90,6 +93,7 @@ static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong
 
 #define QK5_0 32
 #define QR5_0 2
+#define QI5_0 4
 typedef struct {
     half d;                 // delta
     uint8_t qh[4];          // 5-th bit of quants
@@ -99,6 +103,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
 
 #define QK5_1 32
 #define QR5_1 2
+#define QI5_1 4
 typedef struct {
     half d;                 // delta
     half m;                 // min
@@ -109,12 +114,25 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
 
 #define QK8_0 32
 #define QR8_0 1
+#define QI8_0 8
 typedef struct {
     half    d;              // delta
     int8_t  qs[QK8_0];      // quants
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+#define QK8_1 32
+#define QR8_1 1
+#define QI8_1 8
+typedef struct {
+    half    d;              // delta
+    half    s;              // unquantized sum
+    int8_t  qs[QK8_0];      // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
+
+typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
+
 //================================= k-quants
 
 #ifdef GGML_QKK_64
@@ -198,14 +216,15 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
+#define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 
 // dmmv = dequantize_mul_mat_vec
 #ifndef GGML_CUDA_DMMV_X
 #define GGML_CUDA_DMMV_X 32
 #endif
-#ifndef GGML_CUDA_DMMV_Y
-#define GGML_CUDA_DMMV_Y 1
+#ifndef GGML_CUDA_MMV_Y
+#define GGML_CUDA_MMV_Y 1
 #endif
 
 #ifndef K_QUANTS_PER_ITERATION
@@ -270,7 +289,6 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     }
 
     // sum up partial sums
-    __syncthreads();
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
@@ -714,7 +732,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
 #endif
 
     // sum up partial sums and write back result
-    __syncthreads();
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
@@ -819,7 +836,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 #endif
 
     // sum up partial sums and write back result
-    __syncthreads();
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
@@ -923,7 +939,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
 #endif
 
     // sum up partial sums and write back result
-    __syncthreads();
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
@@ -1028,7 +1043,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
 #endif
 
     // sum up partial sums and write back result
-    __syncthreads();
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
@@ -1139,7 +1153,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 #endif
 
     // sum up partial sums and write back result
-    __syncthreads();
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
@@ -1158,6 +1171,41 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
+static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    block_q8_1 * y = (block_q8_1 *) vy;
+
+    const int ib = i / QK8_0; // block index
+    const int iqs = i % QK8_0; // quant index
+
+    const float xi = x[i];
+    float amax = fabsf(xi);
+    float sum = xi;
+
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
+        sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
+    }
+
+    const float d = amax / 127;
+    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+    y[ib].qs[iqs] = q;
+
+    if (iqs > 0) {
+        return;
+    }
+
+    y[ib].d = d;
+    y[ib].s = sum;
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_block(const void * vx, float * y, const int k) {
     const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
@@ -1179,6 +1227,182 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     y[iybs + iqs + y_offset] = v.y;
 }
 
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+
+    int vi;
+    memcpy(&vi,  &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
+    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);
+
+    const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d);
+
+    // subtract 8 from each quantized value
+    const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
+    const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);
+
+    // SIMD dot product of quantized values
+    int sumi = __dp4a(vi0, ui0, 0);
+    sumi     = __dp4a(vi1, ui1, sumi);
+
+    return sumi*d;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= 600
+}
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+
+    const int vi  = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
+    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
+
+    const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d);
+    const float m = bq4_1->m;
+    const float s = bq8_1->s;
+
+    const int vi0 = (vi >> 0) & 0x0F0F0F0F;
+    const int vi1 = (vi >> 4) & 0x0F0F0F0F;
+
+    // SIMD dot product of quantized values
+    int sumi = __dp4a(vi0, ui0, 0);
+    sumi     = __dp4a(vi1, ui1, sumi);
+
+    return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= 600
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+
+    int qs;
+    memcpy(&qs, &bq5_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
+    const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2);
+    const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2);
+    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]);
+
+    const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d);
+
+    int vi0 = (qs  >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
+    vi0    |= (qh0 <<  4) & 0x00000010; // 1 ->  5
+    vi0    |= (qh0 << 11) & 0x00001000; // 2 -> 13
+    vi0    |= (qh0 << 18) & 0x00100000; // 3 -> 21
+    vi0    |= (qh0 << 25) & 0x10000000; // 4 -> 29
+    vi0     = __vsub4(vi0,  0x10101010); // subtract 16 from quantized values
+    int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
+
+    int vi1 = (qs  >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
+    vi1    |= (qh1 <<  4) & 0x00000010; // 1 ->  5
+    vi1    |= (qh1 << 11) & 0x00001000; // 2 -> 13
+    vi1    |= (qh1 << 18) & 0x00100000; // 3 -> 21
+    vi1    |= (qh1 << 25) & 0x10000000; // 4 -> 29
+    vi1     = __vsub4(vi1,  0x10101010); // subtract 16 from quantized values
+    sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
+
+    return sumi*d;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= 600
+}
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+
+    const int qs  = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
+    const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2);
+    const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2);
+    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]);
+
+    const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d);
+    const float m = bq5_1->m;
+    const float s = bq8_1->s;
+
+    int vi0 = (qs  >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
+    vi0    |= (qh0 <<  4) & 0x00000010; // 1 ->  5
+    vi0    |= (qh0 << 11) & 0x00001000; // 2 -> 13
+    vi0    |= (qh0 << 18) & 0x00100000; // 3 -> 21
+    vi0    |= (qh0 << 25) & 0x10000000; // 4 -> 29
+    int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
+
+    int vi1 = (qs  >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
+    vi1    |= (qh1 <<  4) & 0x00000010; // 1 ->  5
+    vi1    |= (qh1 << 11) & 0x00001000; // 2 -> 13
+    vi1    |= (qh1 << 18) & 0x00100000; // 3 -> 21
+    vi1    |= (qh1 << 25) & 0x10000000; // 4 -> 29
+    sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
+
+    return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= 600
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+
+    int vi;
+    memcpy(&vi,  &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
+    const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+
+    const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d);
+
+    // SIMD dot product of quantized values
+    int sumi = __dp4a(vi, ui, 0);
+
+    return sumi*d;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= 600
+}
+
+template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
+static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index
+
+        const int iby = i + threadIdx.x / qi; // y block index
+
+        const int iqs  = threadIdx.x % qi; // x block quant index when casting the quants to int
+
+        tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[row] = tmp;
+    }
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
@@ -1233,7 +1457,6 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
     }
 
     // sum up partial sums and write back result
-    __syncthreads();
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
@@ -1284,7 +1507,6 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
     const int idst = channel*nrows_dst + row_dst;
 
     // sum up partial sums and write back result
-    __syncthreads();
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
@@ -1330,7 +1552,6 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     }
 
     // sum up partial sums and write back result
-    __syncthreads();
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
@@ -1440,7 +1661,6 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
     }
 
     // sum up partial sums
-    __syncthreads();
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
@@ -1494,6 +1714,11 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
     rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
+static void quantize_row_q8_1_cuda(const float * x, void * vy, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, k);
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -1562,45 +1787,45 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -1647,6 +1872,51 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
     dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, vec_dot_q4_0_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, vec_dot_q4_1_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, vec_dot_q5_0_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, vec_dot_q5_1_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, vec_dot_q8_0_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -1654,9 +1924,9 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<1, 1, convert_f16>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -1822,6 +2092,7 @@ static size_t g_scratch_offset = 0;
 
 static int g_device_count = -1;
 static int g_main_device = 0;
+static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@@ -1839,9 +2110,12 @@ void ggml_init_cublas() {
         for (int id = 0; id < g_device_count; ++id) {
             cudaDeviceProp prop;
             CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-            fprintf(stderr, "  Device %d: %s\n", id, prop.name);
+            fprintf(stderr, "  Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
+
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
+
+            g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
         }
         for (int id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
@@ -2057,7 +2331,7 @@ inline void ggml_cuda_op_rms_norm(
     (void) i1;
 }
 
-inline void ggml_cuda_op_dequantize_mul_mat_vec(
+inline void ggml_cuda_op_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
     cudaStream_t & cudaStream_main){
@@ -2069,69 +2343,116 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows = i01_high - i01_low;
 
-// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
-#ifdef GGML_CUDA_DMMV_F16
-    size_t ash;
-    dfloat * src1_dfloat = nullptr; // dfloat == half
+#ifdef GGML_CUDA_FORCE_DMMV
+    const bool use_mul_mat_vec_q = false;
+#else
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
 
-    bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
-        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
-        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+    const bool mul_mat_vec_q_implemented = src0->type == GGML_TYPE_Q4_0 ||
+        src0->type == GGML_TYPE_Q4_1 ||
+        src0->type == GGML_TYPE_Q5_0 ||
+        src0->type == GGML_TYPE_Q5_1 ||
+        src0->type == GGML_TYPE_Q8_0;
 
-    if (src1_convert_f16) {
-        src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
-        ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
-                                ne00, 1, sizeof(float), 0, 0,
-                                ne00, 1, sizeof(half),  0, 0, cudaStream_main);
-    }
+    // The integer intrinsics used in mul_mat_vec_q are available with compute capability 6.
+    // However, they have bad performance with Pascal cards.
+    // Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q.
+    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented;
+#endif
+
+    if (use_mul_mat_vec_q) {
+        size_t as;
+        void * src1_q8_1 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_1)/QK8_1, &as);
+        quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, cudaStream_main);
+
+        switch (src0->type) {
+            case GGML_TYPE_Q4_0:
+                mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q4_1:
+                mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q5_0:
+                mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q5_1:
+                mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q8_0:
+                mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+
+        ggml_cuda_pool_free(src1_q8_1, as);
+    } else {
+        // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+#ifdef GGML_CUDA_DMMV_F16
+        size_t ash;
+        dfloat * src1_dfloat = nullptr; // dfloat == half
+
+        bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+            src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+            src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+        if (src1_convert_f16) {
+            src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+            ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
+                                    ne00, 1, sizeof(float), 0, 0,
+                                    ne00, 1, sizeof(half),  0, 0, cudaStream_main);
+        }
 #else
-    dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
+        dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
 #endif // GGML_CUDA_DMMV_F16
 
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-            dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-            break;
-        case GGML_TYPE_Q4_1:
-            dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-            break;
-        case GGML_TYPE_Q5_0:
-            dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-            break;
-        case GGML_TYPE_Q5_1:
-            dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-            break;
-        case GGML_TYPE_Q8_0:
-            dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-            break;
-        case GGML_TYPE_Q2_K:
-            dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-            break;
-        case GGML_TYPE_Q3_K:
-            dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-            break;
-        case GGML_TYPE_Q4_K:
-            dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-            break;
-        case GGML_TYPE_Q5_K:
-            dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-            break;
-        case GGML_TYPE_Q6_K:
-            dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-            break;
-        case GGML_TYPE_F16:
-            convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+        switch (src0->type) {
+            case GGML_TYPE_Q4_0:
+                dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q4_1:
+                dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q5_0:
+                dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q5_1:
+                dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q8_0:
+                dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q2_K:
+                dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q3_K:
+                dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q4_K:
+                dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q5_K:
+                dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q6_K:
+                dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_F16:
+                convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
 
 #ifdef GGML_CUDA_DMMV_F16
-    if (src1_convert_f16) {
-        ggml_cuda_pool_free(src1_dfloat, ash);
-    }
+        if (src1_convert_f16) {
+            ggml_cuda_pool_free(src1_dfloat, ash);
+        }
 #endif // GGML_CUDA_DMMV_F16
+    }
 
     (void) src1;
     (void) dst;
@@ -2701,8 +3022,8 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     }else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
-            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false);
+        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
         } else {
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
         }
@@ -2835,7 +3156,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
 }
 
 void ggml_cuda_free_data(struct ggml_tensor * tensor) {
-    if (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) {
+    if (!tensor || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) ) {
         return;
     }
 
index fed4ffb0ccd0538b56aa552023581974b3fe2332..fa0bdbefb1de4fba5b64b39707dfca742236e12a 100644 (file)
@@ -1376,7 +1376,7 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
     const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
index 88cbed7d5347c6e8cb07003769385860d4d450c9..635c32eb5e213bacf7b5d50823cffbeaa31d3e6b 100644 (file)
@@ -481,14 +481,14 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
     return GGML_FP32_TO_FP16(x);
 }
 
-void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) {
-    for (size_t i = 0; i < n; i++) {
+void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n) {
+    for (int i = 0; i < n; i++) {
         y[i] = GGML_FP16_TO_FP32(x[i]);
     }
 }
 
-void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
-    size_t i = 0;
+void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) {
+    int i = 0;
 #if defined(__F16C__)
     for (; i + 7 < n; i += 8) {
         __m256 x_vec = _mm256_loadu_ps(x + i);
@@ -1627,109 +1627,112 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
     }
 }
 
+static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
+static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_F32] = {
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type             = GGML_TYPE_F32,
+    },
+    [GGML_TYPE_F16] = {
+        .to_float                 = (ggml_to_float_t) ggml_fp16_to_fp32_row,
+        .from_float               = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+        .from_float_reference     = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f16,
+        .vec_dot_type             = GGML_TYPE_F16,
+    },
     [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q4_0,
-        .quantize_row_q           = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .quantize_row_q_dot       = quantize_row_q8_0,
-        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_0,
+        .from_float               = quantize_row_q4_0,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_0_reference,
+        .vec_dot                  = ggml_vec_dot_q4_0_q8_0,
         .vec_dot_type             = GGML_TYPE_Q8_0,
     },
     [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q         = (dequantize_row_q_t)dequantize_row_q4_1,
-        .quantize_row_q           = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .quantize_row_q_dot       = quantize_row_q8_1,
-        .vec_dot_q                = ggml_vec_dot_q4_1_q8_1,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_1,
+        .from_float               = quantize_row_q4_1,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_1_reference,
+        .vec_dot                  = ggml_vec_dot_q4_1_q8_1,
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
     [GGML_TYPE_Q5_0] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q5_0,
-        .quantize_row_q           = quantize_row_q5_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
-        .quantize_row_q_dot       = quantize_row_q8_0,
-        .vec_dot_q                = ggml_vec_dot_q5_0_q8_0,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_0,
+        .from_float               = quantize_row_q5_0,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_0_reference,
+        .vec_dot                  = ggml_vec_dot_q5_0_q8_0,
         .vec_dot_type             = GGML_TYPE_Q8_0,
     },
     [GGML_TYPE_Q5_1] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q5_1,
-        .quantize_row_q           = quantize_row_q5_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
-        .quantize_row_q_dot       = quantize_row_q8_1,
-        .vec_dot_q                = ggml_vec_dot_q5_1_q8_1,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_1,
+        .from_float               = quantize_row_q5_1,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_1_reference,
+        .vec_dot                  = ggml_vec_dot_q5_1_q8_1,
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
     [GGML_TYPE_Q8_0] = {
-        .dequantize_row_q         = dequantize_row_q8_0,
-        .quantize_row_q           = quantize_row_q8_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
-        .quantize_row_q_dot       = quantize_row_q8_0,
-        .vec_dot_q                = ggml_vec_dot_q8_0_q8_0,
+        .to_float                 = dequantize_row_q8_0,
+        .from_float               = quantize_row_q8_0,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q8_0_reference,
+        .vec_dot                  = ggml_vec_dot_q8_0_q8_0,
         .vec_dot_type             = GGML_TYPE_Q8_0,
     },
     [GGML_TYPE_Q8_1] = {
-        .dequantize_row_q         = NULL,   // TODO
-        .quantize_row_q           = quantize_row_q8_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
-        .quantize_row_q_dot       = quantize_row_q8_1,
-        .vec_dot_q                = NULL,   // TODO
+        .from_float               = quantize_row_q8_1,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q8_1_reference,
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
 #ifdef GGML_USE_K_QUANTS
     [GGML_TYPE_Q2_K] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q2_K,
-        .quantize_row_q           = quantize_row_q2_K,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference,
-        .quantize_row_q_dot       = quantize_row_q8_K,
-        .vec_dot_q                = ggml_vec_dot_q2_K_q8_K,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q2_K,
+        .from_float               = quantize_row_q2_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q2_K_reference,
+        .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q3_K] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q3_K,
-        .quantize_row_q           = quantize_row_q3_K,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference,
-        .quantize_row_q_dot       = quantize_row_q8_K,
-        .vec_dot_q                = ggml_vec_dot_q3_K_q8_K,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q3_K,
+        .from_float               = quantize_row_q3_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q3_K_reference,
+        .vec_dot                  = ggml_vec_dot_q3_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q4_K] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q4_K,
-        .quantize_row_q           = quantize_row_q4_K,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference,
-        .quantize_row_q_dot       = quantize_row_q8_K,
-        .vec_dot_q                = ggml_vec_dot_q4_K_q8_K,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_K,
+        .from_float               = quantize_row_q4_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_K_reference,
+        .vec_dot                  = ggml_vec_dot_q4_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q5_K] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q5_K,
-        .quantize_row_q           = quantize_row_q5_K,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_K_reference,
-        .quantize_row_q_dot       = quantize_row_q8_K,
-        .vec_dot_q                = ggml_vec_dot_q5_K_q8_K,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_K,
+        .from_float               = quantize_row_q5_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_K_reference,
+        .vec_dot                  = ggml_vec_dot_q5_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q6_K] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q6_K,
-        .quantize_row_q           = quantize_row_q6_K,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_K_reference,
-        .quantize_row_q_dot       = quantize_row_q8_K,
-        .vec_dot_q                = ggml_vec_dot_q6_K_q8_K,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q6_K,
+        .from_float               = quantize_row_q6_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q6_K_reference,
+        .vec_dot                  = ggml_vec_dot_q6_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
+    [GGML_TYPE_Q8_K] = {
+        .from_float               = quantize_row_q8_K,
+    }
 #endif
 };
 
 // For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type i) {
     GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
+    return type_traits[i];
 }
 
 
@@ -2275,7 +2278,7 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
 inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
 
-inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
+static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
 #ifdef GGML_SIMD
     float sumf = 0.0f;
     const int np = (n & ~(GGML_F32_STEP - 1));
@@ -2312,7 +2315,7 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
-inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;
 
 #if defined(GGML_SIMD)
@@ -7825,8 +7828,8 @@ static void ggml_compute_forward_dup_f16(
                         id += ne00 * (ne01 - ir1);
                     }
                 }
-            } else if (ggml_is_quantized(dst->type)) {
-                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+            } else if (type_traits[dst->type].from_float) {
+                ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
                 float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
                 size_t id = 0;
@@ -8078,26 +8081,8 @@ static void ggml_compute_forward_dup_f32(
                         id += rs * (ne01 - ir1);
                     }
                 }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_is_quantized(dst->type)) {
-                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+            } else if (type_traits[dst->type].from_float) {
+                ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
 
                 size_t id = 0;
                 size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
@@ -8503,8 +8488,8 @@ static void ggml_compute_forward_add_q_f32(
     const int nth = params->nth;
 
     const enum ggml_type type = src0->type;
-    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+    ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+    ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
@@ -8777,8 +8762,8 @@ static void ggml_compute_forward_add1_q_f32(
     GGML_TENSOR_UNARY_OP_LOCALS;
 
     const enum ggml_type type = src0->type;
-    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+    ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+    ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
 
     // we don't support permuted src0
     GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
@@ -10578,317 +10563,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 }
 #endif
 
-static void ggml_compute_forward_mul_mat_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    assert(ne02 == ne12);
-    assert(ne03 == ne13);
-    assert(ne2  == ne12);
-    assert(ne3  == ne13);
-
-    // we don't support permuted src0 or src1
-    assert(nb00 == sizeof(float));
-    assert(nb10 == sizeof(float));
-
-    // dst cannot be transposed or permuted
-    assert(nb0 == sizeof(float));
-    assert(nb0 <= nb1);
-    assert(nb1 <= nb2);
-    assert(nb2 <= nb3);
-
-    assert(ne0 == ne01);
-    assert(ne1 == ne11);
-    assert(ne2 == ne02);
-    assert(ne3 == ne03);
-
-    // nb01 >= nb00 - src0 is not transposed
-    //   compute by src0 rows
-
-#if defined(GGML_USE_CLBLAST)
-    if (ggml_cl_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#endif
-
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        if (params->ith != 0) {
-            return;
-        }
-
-        if (params->type == GGML_TASK_INIT) {
-            return;
-        }
-
-        if (params->type == GGML_TASK_FINALIZE) {
-            return;
-        }
-
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
-                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-
-                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne00,
-                        0.0f,    d, ne01);
-            }
-        }
-        //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
-
-        return;
-    }
-#endif
-
-    if (params->type == GGML_TASK_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // parallelize by src0 rows using ggml_vec_dot_f32
-
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            // src1 indices
-            const int i13 = i03;
-            const int i12 = i02;
-            const int i11 = ic;
-
-            // dst indices
-            const int i0 = i01;
-            const int i1 = i11;
-            const int i2 = i02;
-            const int i3 = i03;
-
-            ggml_vec_dot_f32(ne00,
-                    (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
-                    (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
-        }
-    }
-
-    //int64_t t1 = ggml_perf_time_us();
-    //static int64_t acc = 0;
-    //acc += t1 - t0;
-    //if (t1 - t0 > 10) {
-    //    printf("\n");
-    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
-    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
-    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
-
-    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
-    //}
-}
-
-static void ggml_compute_forward_mul_mat_f16_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    //const int64_t ne   = ne0*ne1*ne2*ne3;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
-    GGML_ASSERT(ne3  == ne13);
-
-    // TODO: we don't support permuted src0
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
-    // nb01 >= nb00 - src0 is not transposed
-    //   compute by src0 rows
-
-#if defined(GGML_USE_CLBLAST)
-    if (ggml_cl_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#endif
-
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        GGML_ASSERT(nb10 == sizeof(float));
-
-        if (params->ith != 0) {
-            return;
-        }
-
-        if (params->type == GGML_TASK_INIT) {
-            return;
-        }
-
-        if (params->type == GGML_TASK_FINALIZE) {
-            return;
-        }
-
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                float * const wdata = params->wdata;
-                {
-                    size_t id = 0;
-                    for (int64_t i01 = 0; i01 < ne01; ++i01) {
-                        for (int64_t i00 = 0; i00 < ne00; ++i00) {
-                            wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
-                        }
-                    }
-
-                    assert(id*sizeof(float) <= params->wsize);
-                }
-
-                const float * x = wdata;
-                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-
-                // zT = y * xT
-                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne00,
-                        0.0f,    d, ne01);
-            }
-        }
-
-        /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
-
-        return;
-    }
-#endif
-
-    if (params->type == GGML_TASK_INIT) {
-        ggml_fp16_t * const wdata = params->wdata;
-
-        size_t id = 0;
-        for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                    for (int64_t i10 = 0; i10 < ne10; ++i10) {
-                        wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
-                    }
-                }
-            }
-        }
-
-        GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // fp16 -> half the size, so divide by 2
-    // TODO: do not support transposed src1
-    assert(nb10/2 == sizeof(ggml_fp16_t));
-
-    // parallelize by src0 rows using ggml_vec_dot_f16
-
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    ggml_fp16_t * wdata = params->wdata;
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-        const int i13 = i03;
-        const int i12 = i02;
-
-        const int i0 = i01;
-        const int i2 = i02;
-        const int i3 = i03;
-
-        ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
-
-        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
-
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
-        }
-    }
-
-    //int64_t t1 = ggml_time_us();
-    //static int64_t acc = 0;
-    //acc += t1 - t0;
-    //if (t1 - t0 > 10) {
-    //    printf("\n");
-    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
-    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
-    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-
-    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
-    //}
-}
-
-static void ggml_compute_forward_mul_mat_q_f32(
+static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -10907,9 +10582,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
     GGML_ASSERT(ne3  == ne13);
 
     const enum ggml_type type = src0->type;
-    quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
-    vec_dot_q_t      const vec_dot_q          = quantize_fns[type].vec_dot_q;
-    enum ggml_type   const vec_dot_type       = quantize_fns[type].vec_dot_type;
+
+    ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
+    enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
+    ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
@@ -10952,27 +10628,27 @@ static void ggml_compute_forward_mul_mat_q_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
-        dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void * x = (char *) src0->data + i03*nb03 + i02*nb02;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-                {
+                if (type != GGML_TYPE_F32) {
+                    float * const wdata = params->wdata;
+                    ggml_to_float_t const to_float = type_traits[type].to_float;
+
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
-                        dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                        to_float((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
                         id += ne00;
                     }
 
                     assert(id*sizeof(float) <= params->wsize);
+                    x = wdata;
                 }
 
-                const float * x = wdata;
-
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
@@ -10988,14 +10664,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        char * wdata = params->wdata;
-        const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
-
-        for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                    quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                    wdata += row_size;
+        if (src1->type != vec_dot_type) {
+            char * wdata = params->wdata;
+            const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+
+            for (int64_t i13 = 0; i13 < ne13; ++i13) {
+                for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                    for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                        wdata += row_size;
+                    }
                 }
             }
         }
@@ -11019,7 +10697,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    void * wdata = params->wdata;
+    void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
     const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -11043,7 +10721,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
         assert(ne00 % 32 == 0);
 
         for (int64_t ic = 0; ic < ne11; ++ic) {
-            vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
+            vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
         }
     }
 
@@ -11060,40 +10738,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
     //}
 }
 
-static void ggml_compute_forward_mul_mat(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-            {
-                ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
 
 // ggml_compute_forward_out_prod
 
@@ -11483,7 +11127,7 @@ static void ggml_compute_forward_get_rows_q(
     const int nc = src0->ne[0];
     const int nr = ggml_nelements(src1);
     const enum ggml_type type = src0->type;
-    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
 
     assert( dst->ne[0] == nc);
     assert( dst->ne[1] == nr);
@@ -16529,6 +16173,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
 
                         size_t cur = 0;
+                        const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
 
 #if defined(GGML_USE_CUBLAS)
                         if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
@@ -16544,37 +16189,18 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         }
                         else
 #endif
-                        if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                node->n_tasks = 1; // TODO: this actually is doing nothing
-                                                   //       the threads are still spinning
+                        if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                            node->n_tasks = 1; // TODO: this actually is doing nothing
+                                               //       the threads are still spinning
+                            if (node->src0->type != GGML_TYPE_F32) {
                                 // here we need memory just for single 2D matrix from src0
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                            } else {
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
-                            }
-#else
-                            cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
-#endif
-                        } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
-                            cur = 0;
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                node->n_tasks = 1;
                             }
+                        } else
 #endif
-                        } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                node->n_tasks = 1;
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                            } else
-#endif
-                            {
-                                const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
-                                cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
-                            }
+                        if (node->src1->type != vec_dot_type) {
+                            cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
                         } else {
                             GGML_ASSERT(false);
                         }