    const src_t * x = (const src_t *) vx;

-    y[i] = x[i];
+    if constexpr (std::is_same_v<src_t, nv_bfloat16>) {
+        // no implicit nv_bfloat16 -> dst_t conversion: widen to float explicitly
+        // and let the assignment narrow to dst_t where needed
+        y[i] = __bfloat162float(x[i]);
+    } else if constexpr (std::is_same_v<dst_t, nv_bfloat16> && std::is_same_v<src_t, half>) {
+        // half -> nv_bfloat16 also lacks a direct conversion, so hop through float
+        y[i] = (float)x[i];
+    } else {
+        y[i] = x[i];
+    }
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int64_t num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
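+
+// returns a launcher that converts a buffer of `type` to bf16 on `stream`, or
+// nullptr if no such conversion exists; mirrors ggml_get_to_fp16_cuda below.
+//
+// usage sketch (hypothetical buffers; assumes ggml_get_to_fp32_cuda gained a
+// matching GGML_TYPE_BF16 case, as relied on further down in this patch):
+//     to_bf16_cuda_t enc = ggml_get_to_bf16_cuda(GGML_TYPE_F32);
+//     enc(x_f32, x_bf16, n, stream);   // fp32 -> bf16
+//     to_fp32_cuda_t dec = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+//     dec(x_bf16, y_f32, n, stream);   // bf16 -> fp32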
+to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float>;
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half>;
+        default:
+            return nullptr;
+    }
+}
+
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
        case GGML_TYPE_F32:
            return convert_unary_cuda<float>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<nv_bfloat16>;
        default:
            return nullptr;
    }
}

    const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
-    if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
+    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
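+        // src0 is already bf16: convert src1 to bf16 if needed, multiply as bf16,
+        // then convert dst back to fp32 (same shape as the fp16 path below)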
+        ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
+        if (src1->type != GGML_TYPE_BF16) {
+            const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
+            GGML_ASSERT(to_bf16_cuda != nullptr);
+            size_t ne = src1_ncols*ne10;
+            src1_as_bf16.alloc(ne);
+            to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream);
+        }
+        const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
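+        // src0 needs no conversion: the device buffer is already bf16, just cast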
+        const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *) src0_dd_i;
+        ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
+
+        const float alpha_f32 = 1.0f;
+        const float beta_f32 = 0.0f;
+
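+        // with CUBLAS_COMPUTE_32F, cuBLAS reads alpha/beta as fp32 host scalars
+        // even though the A/B/C buffers stay bf16 (CUDA_R_16BF)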
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        CUBLAS_CHECK(
+            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha_f32, src0_ptr,       CUDA_R_16BF, ne00,
+                                src1_ptr,       CUDA_R_16BF, ne10,
+                    &beta_f32,  dst_bf16.get(), CUDA_R_16BF, ldc,
+                    CUBLAS_COMPUTE_32F,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
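+        // the product is written in bf16; convert it to the fp32 this op returns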
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+        to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
        if (src0->type != GGML_TYPE_F16) {