]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : add set rows for bf16 (llama/14664)
authorSigbjørn Skjæret <redacted>
Sun, 13 Jul 2025 13:01:24 +0000 (15:01 +0200)
committerGeorgi Gerganov <redacted>
Sat, 19 Jul 2025 14:47:23 +0000 (17:47 +0300)
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/set-rows.cu

index c7222207efed626ea0c9ebba9953204e74e092e1..8015b0d4e8d9245b439b5bd0a354d303411f7937 100644 (file)
@@ -3226,8 +3226,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             } break;
         case GGML_OP_SET_ROWS:
             {
-#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
-                return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
+                return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) &&
                        op->src[0]->type == GGML_TYPE_F32 &&
                        op->src[1]->type == GGML_TYPE_I64;
             } break;
index d8b3e63e1aa573d11ac6a1bd62f114aea71e72db..3fade72b84eca383de576e65e7280c4a48a39f4b 100644 (file)
@@ -10,6 +10,11 @@ __device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, hal
     *dst_h = __float2half(*src_f);
 }
 
+template<>
+__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
+    *dst_b = *src_f;
+}
+
 template<>
 __device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
     *dst_f = *src_f;
@@ -124,6 +129,16 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             nb1, nb2, nb3,
             stream
         );
+    } else if (dst->type == GGML_TYPE_BF16) {
+        set_rows_cuda(
+            src0_d, src1_d, (nv_bfloat16*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
     } else {
         GGML_ABORT("unsupported type");
     }