CUDA: fix quantized KV cache + multiple sequences (#14822)
author Johannes Gäßler <redacted>
Wed, 23 Jul 2025 10:35:53 +0000 (12:35 +0200)
committer Georgi Gerganov <redacted>
Wed, 23 Jul 2025 11:08:09 +0000 (14:08 +0300)
* CUDA: fix quantized KV cache + multiple sequences

* Update ggml/src/ggml-cuda/fattn-common.cuh

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/src/ggml-cuda/convert.cu
ggml/src/ggml-cuda/fattn-common.cuh

diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index eeaa14bf579500d55c1193b72dc41bbfcdd5e1b0..1b4a71bab074c9f8f2e941909becdd9392fc3e60 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -6,24 +6,33 @@
 #define CUDA_Q8_0_NE_ALIGN 2048
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
-    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t s01, const int64_t s02, const int64_t s03) {
+    const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
 
-    if (i >= k) {
+    if (i00 >= ne00) {
         return;
     }
 
-    const int64_t ib = i/qk; // block index
-    const int64_t iqs = (i%qk)/qr; // quant index
-    const int64_t iybs = i - i%qk; // y block start index
+    const int64_t i01 = blockIdx.y;
+    const int64_t i02 = blockIdx.z % ne02;
+    const int64_t i03 = blockIdx.z / ne02;
+
+    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+
+    const int64_t ib = ibx0 + i00/qk; // block index
+    const int64_t iqs = (i00%qk)/qr; // quant index
+    const int64_t iybs = i00 - i00%qk; // y block start index
     const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
     dfloat2 v;
     dequantize_kernel(vx, ib, iqs, v);
 
-    y[iybs + iqs + 0]        = v.x;
-    y[iybs + iqs + y_offset] = v.y;
+    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
+    y[iy0 + 0]        = v.x;
+    y[iy0 + y_offset] = v.y;
 }
 
 template <bool need_check>
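
The hunk above is the core of the fix: dequantize_block previously addressed the source tensor through a single flat index, which is only correct for contiguously allocated data. A quantized KV cache holding multiple sequences hands the kernel strided K/V views, so the kernel now receives per-dimension extents (ne00..ne02) plus source strides expressed in quantized blocks (s01..s03), recovers its 4D coordinate from blockIdx.y/blockIdx.z (ne03 stays implicit in gridDim.z), and writes to a densely packed destination. A minimal host-side sketch of the same index arithmetic, with hypothetical extents and qk = 32, qr = 1 as for Q8_0 (illustration only, not part of the commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t qk = 32, qr = 1;                    // Q8_0-like block shape
        const int64_t ne00 = 128, ne01 = 4, ne02 = 2;     // elements, rows, heads
        // Dense strides in units of quantized blocks; a padded view would
        // simply make s02/s03 larger without changing the formulas below.
        const int64_t s01 = ne00/qk, s02 = ne01*s01, s03 = ne02*s02;

        // One (i03, i02, i01, i00) coordinate, as derived from blockIdx/threadIdx;
        // i00 is even because each thread dequantizes two values.
        const int64_t i03 = 1, i02 = 1, i01 = 3, i00 = 66;

        const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; // first block of this row
        const int64_t ib   = ibx0 + i00/qk;               // source block index
        const int64_t iqs  = (i00 % qk)/qr;               // quant index within it
        const int64_t iy0  = ((i03*ne02 + i02)*ne01 + i01)*ne00 + (i00 - i00%qk) + iqs;
        printf("ib=%lld iqs=%lld iy0=%lld\n",
               (long long) ib, (long long) iqs, (long long) iy0);
        return 0;
    }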
@@ -457,9 +466,17 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
-    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
-    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+static void dequantize_block_cuda(const void * vx, dst_t * y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
+    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
+    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
+        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+    dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream);
 }
 
 static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
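
dequantize_block_cont_cuda keeps the old flat (vx, y, k, stream) interface for callers that dequantize a contiguous buffer; the GGML_TYPE_Q5_0/Q5_1/Q8_0 table entries in the next two hunks are renamed accordingly but behave exactly as before. The wrapper forwards to the strided launcher as a single logical row, so the higher-dimension strides of k/qk blocks are never actually stepped over. A hedged sketch of the mapping for the Q8_0 instantiation (d_in, d_out and stream are hypothetical):

    // Old-style flat call, unchanged for existing callers:
    dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>(d_in, d_out, k, stream);
    // ...which forwards as one row of k elements (ne01 = ne02 = ne03 = 1):
    dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0, half>(
        d_in, d_out, /*ne00=*/k, /*ne01=*/1, /*ne02=*/1, /*ne03=*/1,
        /*s01=*/k/QK8_0, /*s02=*/k/QK8_0, /*s03=*/k/QK8_0, stream);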
@@ -624,14 +641,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q4_1:
             return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
-            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
-            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
             if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
                 return dequantize_block_q8_0_f16_cuda;
             }
-            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_Q2_K:
             return dequantize_row_q2_K_cuda;
         case GGML_TYPE_Q3_K:
@@ -676,11 +693,11 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
         case GGML_TYPE_Q4_1:
             return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
-            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
-            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
-            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_Q2_K:
             return dequantize_row_q2_K_cuda;
         case GGML_TYPE_Q3_K:
@@ -722,6 +739,16 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_BF16:
             return convert_unary_cuda<nv_bfloat16>;
         default:
@@ -733,6 +760,16 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda<float, nv_bfloat16>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half, nv_bfloat16>;
         default:
@@ -744,6 +781,16 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F16:
             return convert_unary_cuda<half, float>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_BF16:
             return convert_unary_cuda<nv_bfloat16, float>;
         default:
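
With these cases added, the non-contiguous ("nc") converter getters can return strided dequantizers for the legacy quant types Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0, which is what lets launch_fattn (below) convert a strided quantized K/V view instead of asserting contiguity. A hedged usage sketch mirroring the new call site in fattn-common.cuh (src_data, dst_f16 and the extents/strides are assumed names; strides are byte strides divided by the quantized type size, i.e. counted in blocks):

    to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(GGML_TYPE_Q8_0);
    to_fp16(src_data, dst_f16,
            ne00, ne01, ne02, ne03,
            nb01/ts, nb02/ts, nb03/ts,  // s01..s03 in quantized blocks
            stream);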
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 9122fca6cf99f3ed475167de211716a260bf70bb..3644ddf2fdf360f852be265161b65f07431639c7 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -745,33 +745,58 @@ void launch_fattn(
     size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(K));
-        K_f16.alloc(ggml_nelements(K));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
-        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
-        K_data = (char *) K_f16.ptr;
-
         const size_t bs = ggml_blck_size(K->type);
         const size_t ts = ggml_type_size(K->type);
 
-        nb11 = nb11*bs*sizeof(half)/ts;
-        nb12 = nb12*bs*sizeof(half)/ts;
-        nb13 = nb13*bs*sizeof(half)/ts;
+        K_f16.alloc(ggml_nelements(K));
+        if (ggml_is_contiguously_allocated(K)) {
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+
+            nb11 = nb11*bs*sizeof(half)/ts;
+            nb12 = nb12*bs*sizeof(half)/ts;
+            nb13 = nb13*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(K->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
+            const int64_t s01 = nb11 / ts;
+            const int64_t s02 = nb12 / ts;
+            const int64_t s03 = nb13 / ts;
+            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
+
+            nb11 = K->ne[0] * sizeof(half);
+            nb12 = K->ne[1] * nb11;
+            nb13 = K->ne[2] * nb12;
+        }
+        K_data = (char *) K_f16.ptr;
     }
 
     if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(V));
-        V_f16.alloc(ggml_nelements(V));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
-        V_data = (char *) V_f16.ptr;
-
         const size_t bs = ggml_blck_size(V->type);
         const size_t ts = ggml_type_size(V->type);
 
-        nb21 = nb21*bs*sizeof(half)/ts;
-        nb22 = nb22*bs*sizeof(half)/ts;
-        nb23 = nb23*bs*sizeof(half)/ts;
+        V_f16.alloc(ggml_nelements(V));
+        if (ggml_is_contiguously_allocated(V)) {
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+            V_data = (char *) V_f16.ptr;
+
+            nb21 = nb21*bs*sizeof(half)/ts;
+            nb22 = nb22*bs*sizeof(half)/ts;
+            nb23 = nb23*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(V->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
+            const int64_t s01 = nb21 / ts;
+            const int64_t s02 = nb22 / ts;
+            const int64_t s03 = nb23 / ts;
+            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+            nb21 = V->ne[0] * sizeof(half);
+            nb22 = V->ne[1] * nb21;
+            nb23 = V->ne[2] * nb22;
+        }
+        V_data = (char *) V_f16.ptr;
     }
 
     int parallel_blocks = 1;
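
launch_fattn now branches on ggml_is_contiguously_allocated(): the old bulk conversion remains the fast path, while strided K/V views (the multiple-sequence case this commit fixes) go through the nc converters with strides counted in quantized blocks, after which nb11..nb13 / nb21..nb23 are rewritten to describe the dense F16 buffer that was just produced. Worked numbers for a hypothetical Q8_0 K view (sizeof(block_q8_0) = 34 bytes, QK8_0 = 32; all values illustrative):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne[4] = {128, 256, 8, 1}; // head dim, KV length, heads, seqs
        const int64_t ts = 34;                  // bytes per Q8_0 block
        // Byte strides of a view into a padded cache: rows are dense
        // (128/32 = 4 blocks = 136 bytes), but nb12 skips a 512-row
        // allocation, so ggml_is_contiguously_allocated() is false.
        const int64_t nb11 = 136, nb12 = 136*512, nb13 = 8*nb12;

        // Strides handed to the nc converter, counted in quantized blocks:
        printf("s01=%lld s02=%lld s03=%lld\n",
               (long long)(nb11/ts), (long long)(nb12/ts), (long long)(nb13/ts));

        // Byte strides of the dense F16 buffer afterwards:
        const int64_t o11 = ne[0]*2;            // ne[0] * sizeof(half)
        const int64_t o12 = ne[1]*o11, o13 = ne[2]*o12;
        printf("nb11=%lld nb12=%lld nb13=%lld\n",
               (long long) o11, (long long) o12, (long long) o13);
        return 0;
    }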