} break;
case GGML_OP_SET_ROWS:
{
-#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
- return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
+ return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_I64;
} break;
*dst_h = __float2half(*src_f);
}
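+// F32 -> BF16 store; convert explicitly with __float2bfloat16 (round to
+// nearest even), mirroring the __float2half path above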
+template<>
+__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
+    *dst_b = __float2bfloat16(*src_f);
+}
+
template<>
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
*dst_f = *src_f;
nb1, nb2, nb3,
stream
);
+ } else if (dst->type == GGML_TYPE_BF16) {
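+        // the cast selects the nv_bfloat16 instantiation of set_rows_cuda, so
+        // the kernel stores through the F32 -> BF16 set_rows_1 specialization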
+ set_rows_cuda(
+ src0_d, src1_d, (nv_bfloat16*)dst->data,
+ ne00, ne01, ne02, ne03,
+ ne10, ne11, ne12, ne13,
+ nb01, nb02, nb03,
+ nb10, nb11, nb12,
+ nb1, nb2, nb3,
+ stream
+ );
} else {
GGML_ABORT("unsupported type");
}