]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda: add q8_0->f32 cpy operation (#9571)
authorIvan <redacted>
Tue, 24 Sep 2024 00:14:24 +0000 (03:14 +0300)
committerGitHub <redacted>
Tue, 24 Sep 2024 00:14:24 +0000 (02:14 +0200)
llama: enable K-shift for quantized KV cache
It will fail on unsupported backends or quant types.

ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/cpy.cu
src/llama.cpp

index a0d2561009f5837261d29779639ccc82c7b7d8b5..0bb7f2d997543a1ba29de4a3a91c6c8833368808 100644 (file)
@@ -2899,6 +2899,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
                     return true;
                 }
index 51deb75fd5f81ea2195f84c72dae76ad85f11e9b..54c0f66d2dfeddb17ef39352935aae0028d58192 100644 (file)
@@ -81,6 +81,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
     }
 }
 
+static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
+    const block_q8_0 * xi = (const block_q8_0 *) cxi;
+    float * dsti = (float *) cdsti;
+
+    const float d = (float)xi->d;
+
+    for (int j = 0; j < QK8_0; j++) {
+       dsti[j] = xi->qs[j] * d;
+    }
+}
+
 static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     block_q4_0 * dsti = (block_q4_0 *) cdsti;
@@ -288,6 +299,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
 static void ggml_cpy_f16_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -329,6 +366,16 @@ static void ggml_cpy_f32_q8_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q8_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -437,6 +484,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -471,6 +520,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
index bc4e408e010d09d87cff97f9a7066b87d897db44..e5e0d1a66e51c617c61966733d7a7df8680f01d9 100644 (file)
@@ -9930,17 +9930,36 @@ struct llm_build_context {
             const int64_t n_head_kv = hparams.n_head_kv(il);
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             struct ggml_tensor * rope_factors = build_rope_factors(il);
-            struct ggml_tensor * tmp =
+            struct ggml_tensor * k =
+                ggml_view_3d(ctx0, kv_self.k_l[il],
+                    n_embd_head_k, n_head_kv, n_ctx,
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                    0);
+
+            struct ggml_tensor * tmp;
+            if (ggml_is_quantized(k->type)) {
+                // dequantize to f32 -> RoPE -> quantize back
+                tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
+                cb(tmp, "K_f32", il);
+                for (auto * backend : lctx.backends) {
+                    // Figure out which backend KV cache belongs to
+                    if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft)) {
+                        ggml_backend_sched_set_tensor_backend(lctx.sched, tmp, backend);
+                        break;
+                    }
+                }
+                tmp = ggml_rope_ext_inplace(ctx0, tmp,
+                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(tmp, "K_shifted_f32", il);
+                tmp = ggml_cpy(ctx0, tmp, k);
+            } else {
                 // we rotate only the first n_rot dimensions
-                ggml_rope_ext_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k_l[il],
-                            n_embd_head_k, n_head_kv, n_ctx,
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                            0),
+                tmp = ggml_rope_ext_inplace(ctx0, k,
                         lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
-
+            }
             cb(tmp, "K_shifted", il);
             ggml_build_forward_expand(gf, tmp);
         }