]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cuda : perform cublas mat mul of quantized types as f16 (#3412)
authorslaren <redacted>
Sat, 30 Sep 2023 16:12:57 +0000 (18:12 +0200)
committerGitHub <redacted>
Sat, 30 Sep 2023 16:12:57 +0000 (18:12 +0200)
* ggml-cuda : perform cublas matrix multiplication of quantized types as fp16

* rename CC_TURING to CC_VOLTA

* disable fp16 mat mul completely with multi GPU

ggml-cuda.cu

index 86d1fe203a4653d7c5ea1697f4cf7de85ea4b240..989c419cd0ea44a6db666f3b63958f9ae6bb642f 100644 (file)
@@ -80,9 +80,9 @@
 #include "ggml.h"
 
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define CC_TURING     700
+#define CC_VOLTA      700
 #define CC_OFFSET_AMD 1000000
-#define CC_RDNA2      CC_OFFSET_AMD + 1030
+#define CC_RDNA2      (CC_OFFSET_AMD + 1030)
 
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
@@ -715,7 +715,8 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 
 //================================== k-quants
 
-static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
     const int i   = blockIdx.x;
     const block_q2_K * x = (const block_q2_K *) vx;
@@ -727,7 +728,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const int is  = 8*n + l/16;
 
     const uint8_t q = x[i].qs[32*n + l];
-    float * y = yy + i*QK_K + 128*n;
+    dst_t * y = yy + i*QK_K + 128*n;
 
     float dall = __low2half(x[i].dm);
     float dmin = __high2half(x[i].dm);
@@ -739,7 +740,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const int is = tid/16;  // 0 or 1
     const int il = tid%16;  // 0...15
     const uint8_t q = x[i].qs[il] >> (2*is);
-    float * y = yy + i*QK_K + 16*is + il;
+    dst_t * y = yy + i*QK_K + 16*is + il;
     float dall = __low2half(x[i].dm);
     float dmin = __high2half(x[i].dm);
     y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
@@ -748,7 +749,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
 
 }
 
-static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
     const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
@@ -772,7 +774,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
     float d_all = x[i].d;
     float dl = d_all * (us - 32);
 
-    float * y = yy + i*QK_K + 128*n + 32*j;
+    dst_t * y = yy + i*QK_K + 128*n + 32*j;
     const uint8_t * q = x[i].qs + 32*n;
     const uint8_t * hm = x[i].hmask;
 
@@ -784,7 +786,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
     const int im  = il/8;    // 0...1
     const int in  = il%8;    // 0...7
 
-    float * y = yy + i*QK_K + 16*is + il;
+    dst_t * y = yy + i*QK_K + 16*is + il;
 
     const uint8_t q = x[i].qs[il] >> (2*is);
     const uint8_t h = x[i].hmask[in] >> (2*is + im);
@@ -812,7 +814,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
 }
 #endif
 
-static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = blockIdx.x;
@@ -825,7 +828,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
     const int is  = 2*il;
     const int n   = 4;
 
-    float * y = yy + i*QK_K + 64*il + n*ir;
+    dst_t * y = yy + i*QK_K + 64*il + n*ir;
 
     const float dall = __low2half(x[i].dm);
     const float dmin = __high2half(x[i].dm);
@@ -844,7 +847,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
 #else
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
-    float * y = yy + i*QK_K;
+    dst_t * y = yy + i*QK_K;
     const float d = (float)x[i].dm[0];
     const float m = (float)x[i].dm[1];
     y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
@@ -852,7 +855,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
 #endif
 }
 
-static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q5_K * x = (const block_q5_K *) vx;
 
     const int i = blockIdx.x;
@@ -864,7 +868,7 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
     const int ir  = tid%16;   // ir is in 0...15
     const int is  = 2*il;     // is is in 0...6
 
-    float * y = yy + i*QK_K + 64*il + 2*ir;
+    dst_t * y = yy + i*QK_K + 64*il + 2*ir;
 
     const float dall = __low2half(x[i].dm);
     const float dmin = __high2half(x[i].dm);
@@ -892,13 +896,14 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
     const int is = tid/16; // 0 or 1
     const uint8_t h = x[i].qh[in] >> im;
     const float d = x[i].d;
-    float * y = yy + i*QK_K + tid;
+    dst_t * y = yy + i*QK_K + tid;
     y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
     y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
 #endif
 }
 
-static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int i = blockIdx.x;
@@ -910,7 +915,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
     const int il  = tid - 32*ip; // 0...32
     const int is  = 8*ip + il/16;
 
-    float * y = yy + i*QK_K + 128*ip + il;
+    dst_t * y = yy + i*QK_K + 128*ip + il;
 
     const float d = x[i].d;
 
@@ -929,7 +934,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
     const int ip  = tid/16;         // 0 or 1
     const int il  = tid - 16*ip;    // 0...15
 
-    float * y = yy + i*QK_K + 16*ip + il;
+    dst_t * y = yy + i*QK_K + 16*ip + il;
 
     const float d = x[i].d;
 
@@ -3548,7 +3553,7 @@ template <bool need_check> static __global__ void
         load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 
-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
     const int nwarps = NWARPS_Q4_0_AMPERE;
@@ -3568,7 +3573,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q4_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
 #define  MMQ_X_Q4_1_RDNA2  64
@@ -3589,9 +3594,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q4_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
@@ -3611,7 +3616,7 @@ template <bool need_check> static __global__ void
         load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 
-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_1_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;
     const int nwarps = NWARPS_Q4_1_AMPERE;
@@ -3631,7 +3636,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q4_1_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
 #define  MMQ_X_Q5_0_RDNA2  64
@@ -3672,7 +3677,7 @@ template <bool need_check> static __global__ void
         load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 
-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;
     const int nwarps = NWARPS_Q5_0_AMPERE;
@@ -3692,7 +3697,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q5_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
 #define  MMQ_X_Q5_1_RDNA2  64
@@ -3733,7 +3738,7 @@ mul_mat_q5_1(
         load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 
-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_1_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;
     const int nwarps = NWARPS_Q5_1_AMPERE;
@@ -3753,7 +3758,7 @@ mul_mat_q5_1(
 #else
     (void) vec_dot_q5_1_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
 #define  MMQ_X_Q8_0_RDNA2  64
@@ -3794,7 +3799,7 @@ template <bool need_check> static __global__ void
         load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 
-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q8_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;
     const int nwarps = NWARPS_Q8_0_AMPERE;
@@ -3814,7 +3819,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q8_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
 #define  MMQ_X_Q2_K_RDNA2  64
@@ -3855,7 +3860,7 @@ mul_mat_q2_K(
         load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 
-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q2_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;
     const int nwarps = NWARPS_Q2_K_AMPERE;
@@ -3875,7 +3880,7 @@ mul_mat_q2_K(
 #else
     (void) vec_dot_q2_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
 #define  MMQ_X_Q3_K_RDNA2  128
@@ -3896,9 +3901,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q3_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
@@ -3918,7 +3923,7 @@ template <bool need_check> static __global__ void
         load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 
-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q3_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;
     const int nwarps = NWARPS_Q3_K_AMPERE;
@@ -3938,7 +3943,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q3_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
 #define  MMQ_X_Q4_K_RDNA2  64
@@ -3959,9 +3964,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q4_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
@@ -3981,7 +3986,7 @@ template <bool need_check> static __global__ void
         load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 
-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;
     const int nwarps = NWARPS_Q4_K_AMPERE;
@@ -4001,7 +4006,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q4_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
 #define  MMQ_X_Q5_K_RDNA2  64
@@ -4042,7 +4047,7 @@ mul_mat_q5_K(
         load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 
-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;
     const int nwarps = NWARPS_Q5_K_AMPERE;
@@ -4062,7 +4067,7 @@ mul_mat_q5_K(
 #else
     (void) vec_dot_q5_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
 #define  MMQ_X_Q6_K_RDNA2  64
@@ -4083,9 +4088,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q6_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
@@ -4105,7 +4110,7 @@ template <bool need_check> static __global__ void
         load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 
-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q6_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;
     const int nwarps = NWARPS_Q6_K_AMPERE;
@@ -4125,7 +4130,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q6_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
@@ -4604,32 +4609,38 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
 }
 
-static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4638,7 +4649,8 @@ static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }
 
-static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4647,12 +4659,14 @@ static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }
 
-static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
 }
 
-static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4661,7 +4675,8 @@ static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }
 
-static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4868,6 +4883,26 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
 
 static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_row_q5_0_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_row_q5_1_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_row_q8_0_cuda;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
         case GGML_TYPE_F32:
             return convert_fp32_to_fp16_cuda;
         default:
@@ -4921,7 +4956,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
         mmq_x  =  MMQ_X_Q4_0_RDNA1;
         mmq_y  =  MMQ_Y_Q4_0_RDNA1;
         nwarps = NWARPS_Q4_0_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_0_AMPERE;
         mmq_y  =  MMQ_Y_Q4_0_AMPERE;
         nwarps = NWARPS_Q4_0_AMPERE;
@@ -4966,7 +5001,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
         mmq_x  =  MMQ_X_Q4_1_RDNA1;
         mmq_y  =  MMQ_Y_Q4_1_RDNA1;
         nwarps = NWARPS_Q4_1_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_1_AMPERE;
         mmq_y  =  MMQ_Y_Q4_1_AMPERE;
         nwarps = NWARPS_Q4_1_AMPERE;
@@ -5011,7 +5046,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
         mmq_x  =  MMQ_X_Q5_0_RDNA1;
         mmq_y  =  MMQ_Y_Q5_0_RDNA1;
         nwarps = NWARPS_Q5_0_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_0_AMPERE;
         mmq_y  =  MMQ_Y_Q5_0_AMPERE;
         nwarps = NWARPS_Q5_0_AMPERE;
@@ -5056,7 +5091,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
         mmq_x  =  MMQ_X_Q5_1_RDNA1;
         mmq_y  =  MMQ_Y_Q5_1_RDNA1;
         nwarps = NWARPS_Q5_1_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_1_AMPERE;
         mmq_y  =  MMQ_Y_Q5_1_AMPERE;
         nwarps = NWARPS_Q5_1_AMPERE;
@@ -5101,7 +5136,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
         mmq_x  =  MMQ_X_Q8_0_RDNA1;
         mmq_y  =  MMQ_Y_Q8_0_RDNA1;
         nwarps = NWARPS_Q8_0_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q8_0_AMPERE;
         mmq_y  =  MMQ_Y_Q8_0_AMPERE;
         nwarps = NWARPS_Q8_0_AMPERE;
@@ -5146,7 +5181,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q2_K_RDNA1;
         mmq_y  =  MMQ_Y_Q2_K_RDNA1;
         nwarps = NWARPS_Q2_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q2_K_AMPERE;
         mmq_y  =  MMQ_Y_Q2_K_AMPERE;
         nwarps = NWARPS_Q2_K_AMPERE;
@@ -5193,7 +5228,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q3_K_RDNA1;
         mmq_y  =  MMQ_Y_Q3_K_RDNA1;
         nwarps = NWARPS_Q3_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q3_K_AMPERE;
         mmq_y  =  MMQ_Y_Q3_K_AMPERE;
         nwarps = NWARPS_Q3_K_AMPERE;
@@ -5239,7 +5274,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q4_K_RDNA1;
         mmq_y  =  MMQ_Y_Q4_K_RDNA1;
         nwarps = NWARPS_Q4_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_K_AMPERE;
         mmq_y  =  MMQ_Y_Q4_K_AMPERE;
         nwarps = NWARPS_Q4_K_AMPERE;
@@ -5284,7 +5319,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q5_K_RDNA1;
         mmq_y  =  MMQ_Y_Q5_K_RDNA1;
         nwarps = NWARPS_Q5_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_K_AMPERE;
         mmq_y  =  MMQ_Y_Q5_K_AMPERE;
         nwarps = NWARPS_Q5_K_AMPERE;
@@ -5329,7 +5364,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q6_K_RDNA1;
         mmq_y  =  MMQ_Y_Q6_K_RDNA1;
         nwarps = NWARPS_Q6_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q6_K_AMPERE;
         mmq_y  =  MMQ_Y_Q6_K_AMPERE;
         nwarps = NWARPS_Q6_K_AMPERE;
@@ -5907,7 +5942,7 @@ static int64_t get_row_rounding(ggml_type type) {
     switch(type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
-            return max_compute_capability >= CC_TURING ? 128 : 64;
+            return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
@@ -5918,7 +5953,7 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
-            return max_compute_capability >= CC_TURING ? 128 : 64;
+            return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
         default:
@@ -6083,8 +6118,19 @@ inline void ggml_cuda_op_mul_mat_cublas(
 
     const int compute_capability = g_compute_capabilities[id];
 
-    if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous(src0) && ldc == row_diff) {
-        // convert src1 to fp16, multiply as fp16, convert dst to fp32
+    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
+        half * src0_as_f16 = nullptr;
+        size_t src0_as = 0;
+        if (src0->type != GGML_TYPE_F16) {
+            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
+            size_t ne = row_diff*ne00;
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
+            to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
+        }
+        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
+
         half * src1_as_f16 = nullptr;
         size_t src1_as = 0;
         if (src1->type != GGML_TYPE_F16) {
@@ -6106,9 +6152,9 @@ inline void ggml_cuda_op_mul_mat_cublas(
         CUBLAS_CHECK(
             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                     row_diff, src1_ncols, ne10,
-                    &alpha_f16, src0_dd_i, CUDA_R_16F, ne00,
-                                src1_ptr,  CUDA_R_16F, ne10,
-                    &beta_f16,   dst_f16,  CUDA_R_16F, ldc,
+                    &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
+                                src1_ptr, CUDA_R_16F, ne10,
+                    &beta_f16,   dst_f16, CUDA_R_16F, ldc,
                     CUBLAS_COMPUTE_16F,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
@@ -6117,6 +6163,10 @@ inline void ggml_cuda_op_mul_mat_cublas(
 
         ggml_cuda_pool_free(dst_f16, dst_as);
 
+        if (src0_as != 0) {
+            ggml_cuda_pool_free(src0_as_f16, src0_as);
+        }
+
         if (src1_as != 0) {
             ggml_cuda_pool_free(src1_as_f16, src1_as);
         }