CUDA: faster dequantize kernels for Q4_0 and Q4_1 (llama/4938)
author Kawrakow <redacted>
Mon, 15 Jan 2024 05:48:06 +0000 (07:48 +0200)
committer Georgi Gerganov <redacted>
Wed, 17 Jan 2024 19:21:09 +0000 (21:21 +0200)
Co-authored-by: Iwan Kawrakow <redacted>
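
The commit replaces the generic dequantize_block_cuda path for Q4_0 and Q4_1 with dedicated kernels. For context, Q4_0 packs 32 weights into 16 bytes of 4-bit quants plus a half-precision scale d (a weight dequantizes to d*(q - 8)), while Q4_1 additionally stores a minimum m (a weight dequantizes to d*q + m). Below is a minimal CPU-side reference sketch of that layout and arithmetic; the struct names are illustrative stand-ins for ggml's block_q4_0/block_q4_1, which store the scales in half precision rather than float.

// Minimal CPU reference for the two 4-bit block formats touched by this commit.
// Struct layouts mirror ggml's block_q4_0 / block_q4_1 but use plain float scales
// so the sketch stays self-contained (the real structs store half precision).
#include <cstdint>

struct ref_block_q4_0 {      // 32 weights: scale d + 16 bytes of packed 4-bit quants
    float   d;
    uint8_t qs[16];
};

struct ref_block_q4_1 {      // 32 weights: scale d, minimum m + 16 bytes of quants
    float   d;
    float   m;
    uint8_t qs[16];
};

// y[j] = d * (q[j] - 8); low nibbles hold values 0..15, high nibbles 16..31
static void ref_dequantize_q4_0(const ref_block_q4_0 & b, float * y) {
    for (int j = 0; j < 16; ++j) {
        y[j +  0] = b.d * ((b.qs[j] & 0xF) - 8);
        y[j + 16] = b.d * ((b.qs[j] >>  4) - 8);
    }
}

// y[j] = d * q[j] + m
static void ref_dequantize_q4_1(const ref_block_q4_1 & b, float * y) {
    for (int j = 0; j < 16; ++j) {
        y[j +  0] = b.d * (b.qs[j] & 0xF) + b.m;
        y[j + 16] = b.d * (b.qs[j] >>  4) + b.m;
    }
}

Writing the low and high nibbles 16 elements apart matches ggml's row layout for these formats, which is why both this reference loop and the new kernels use the y[l] / y[l+16] split.
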
ggml-cuda.cu

index bd3814c72b407d29c572953fd99f71a894f2f943..a870718a745cbeee66b3df1006337f8969974ceb 100644 (file)
@@ -1105,6 +1105,61 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 #endif // GGML_CUDA_F16
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+    const int i = blockIdx.x;
+
+    // 32 threads per block; each CUDA block covers 256 output values (8 quant blocks)
+    const int tid = threadIdx.x;
+    const int il  = tid/8;    // 4-value chunk within a 32-value quant block (0..3)
+    const int ir  = tid%8;    // which of this CUDA block's 8 quant blocks (0..7)
+    const int ib  = 8*i + ir; // global index of the quant block
+    if (ib >= nb32) {
+        return;
+    }
+
+    dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+    const block_q4_0 * x = (const block_q4_0 *)vx + ib;
+    const float d = __half2float(x->d);
+    const float dm = -8*d; // Q4_0 quants carry an implicit offset of 8: y = d*(q - 8)
+
+    const uint8_t * q = x->qs + 4*il;
+
+    for (int l = 0; l < 4; ++l) {
+        y[l+ 0] = d * (q[l] & 0xF) + dm;
+        y[l+16] = d * (q[l] >>  4) + dm;
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+    const int i = blockIdx.x;
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int il  = tid/8;
+    const int ir  = tid%8;
+    const int ib = 8*i + ir;
+    if (ib >= nb32) {
+        return;
+    }
+
+    dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+    const block_q4_1 * x = (const block_q4_1 *)vx + ib;
+    const float2 d = __half22float2(x->dm); // dm packs (scale, min): y = d.x*q + d.y
+
+    const uint8_t * q = x->qs + 4*il;
+
+    for (int l = 0; l < 4; ++l) {
+        y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
+        y[l+16] = d.x * (q[l] >>  4) + d.y;
+    }
+}
+
 //================================== k-quants
 
 template<typename dst_t>
@@ -6253,6 +6308,20 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
 #endif
 }
 
+template<typename dst_t>
+static void dequantize_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb32 = k / 32;
+    const int nb = (k + 255) / 256;
+    dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
+template<typename dst_t>
+static void dequantize_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb32 = k / 32;
+    const int nb = (k + 255) / 256;
+    dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
 template<typename dst_t>
 static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
@@ -6301,9 +6370,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     int id;
     switch (type) {
         case GGML_TYPE_Q4_0:
-            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+            return dequantize_q4_0_cuda;
         case GGML_TYPE_Q4_1:
-            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+            return dequantize_q4_1_cuda;
         case GGML_TYPE_Q5_0:
             return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
@@ -6338,9 +6407,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
-            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+            return dequantize_q4_0_cuda;
         case GGML_TYPE_Q4_1:
-            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+            return dequantize_q4_1_cuda;
         case GGML_TYPE_Q5_0:
             return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
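
For reference, the launchers above start (k + 255)/256 blocks of 32 threads, so each CUDA block covers eight 32-weight quant blocks and each thread writes four low-nibble and four high-nibble values. A small host-side sketch (illustrative only, not part of ggml) that replays the index arithmetic from dequantize_block_q4_0 and checks that every output element is written exactly once:

// Replays the thread -> output mapping of the new kernels on the host:
// grid of (k + 255)/256 blocks, 32 threads each; thread tid of block i
// handles quant block ib = 8*i + tid%8 and writes 8 of its 32 values.
#include <cassert>
#include <vector>

int main() {
    const int k    = 4096;            // number of weights (assumed multiple of 32)
    const int nb32 = k / 32;          // number of 32-weight quant blocks
    const int grid = (k + 255) / 256; // blocks launched by dequantize_q4_0_cuda

    std::vector<int> hits(k, 0);

    for (int i = 0; i < grid; ++i) {         // blockIdx.x
        for (int tid = 0; tid < 32; ++tid) { // threadIdx.x
            const int il = tid/8;
            const int ir = tid%8;
            const int ib = 8*i + ir;
            if (ib >= nb32) continue;        // same guard as the kernel

            const int base = 256*i + 32*ir + 4*il;
            for (int l = 0; l < 4; ++l) {
                ++hits[base + l +  0];       // low-nibble value
                ++hits[base + l + 16];       // high-nibble value
            }
        }
    }

    for (int j = 0; j < k; ++j) {
        assert(hits[j] == 1);                // each output written exactly once
    }
    return 0;
}

The same mapping applies to dequantize_block_q4_1; only the per-block scale/min arithmetic differs.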