]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : add f32 to bf16 copy op (#12806)
authorSigbjørn Skjæret <redacted>
Tue, 8 Apr 2025 21:21:31 +0000 (23:21 +0200)
committerGitHub <redacted>
Tue, 8 Apr 2025 21:21:31 +0000 (23:21 +0200)
This allows BF16 KV-cache on CUDA.

ggml/src/ggml-cuda/cpy.cu
ggml/src/ggml-cuda/ggml-cuda.cu

index ed853ee6c15a23abd6a55ff31a24a54947ddd8fe..4f4faa3e63ae7c1245eeeedd6f940d576ddc3f85 100644 (file)
@@ -10,6 +10,13 @@ static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
     *dsti = *xi;
 }
 
+static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
+
+    *dsti = *xi;
+}
+
 static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     half * dsti = (half *) cdsti;
@@ -386,6 +393,16 @@ static void ggml_cpy_f32_f32_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+static void ggml_cpy_f32_bf16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -581,6 +598,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
@@ -634,6 +653,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
index 78717df1a6095165aef17e5f49ac98f4916728ae..633456a92d0defe5b4d25b2f02f31419b215fc8c 100644 (file)
@@ -3079,6 +3079,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
                     return true;
                 }