}
}
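+// Compute-type overrides for cuBLAS GEMMs, parsed from the environment.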
+struct cublas_force_compute_type {
+ bool fp32 = false;
+ bool fp16 = false;
+};
+
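+// Parses GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F / GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F once
+// (the static local initializer is thread-safe) and caches the result for later calls.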
+static const cublas_force_compute_type & ggml_cuda_cublas_get_force_compute_type() {
+ static const cublas_force_compute_type compute_type = [] {
+ cublas_force_compute_type result;
+
+ const bool ggml_cuda_force_cublas_compute_32f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F") != nullptr;
+ const bool ggml_cuda_force_cublas_compute_16f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F") != nullptr;
+
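+ // The two overrides are mutually exclusive; abort if both are set.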
+ GGML_ASSERT(!(ggml_cuda_force_cublas_compute_32f_env && ggml_cuda_force_cublas_compute_16f_env));
+
+ if (ggml_cuda_force_cublas_compute_32f_env) {
+ GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F\n");
+ result.fp32 = true;
+ } else if (ggml_cuda_force_cublas_compute_16f_env) {
+ GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F\n");
+ result.fp16 = true;
+ }
+
+ return result;
+ }();
+
+ return compute_type;
+}
+
static void ggml_cuda_op_mul_mat_cublas(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
- if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+ const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
+
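+ // Take the fp32 GEMM path on CDNA, RDNA4 and Volta, or when fp32 is forced via the environment,
+ // unless fp16 is explicitly forced.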
+ if (!force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
+ || GGML_CUDA_CC_IS_RDNA4(cc)
+ || cc == GGML_CUDA_CC_VOLTA
+ || force_compute_type.fp32)) {
const float alpha = 1.0f;
const float beta = 0.0f;
CUBLAS_CHECK(
cudaDataType_t cu_data_type_b = traits::data_type;
const void * alpha = traits::get_alpha();
const void * beta = traits::get_beta();
- const float alpha_f32 = 1.0f;
- const float beta_f32 = 0.0f;
- if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+ const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
+
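+ // Whether fp32 accumulation is needed depends on the device architecture, so query it here.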
+ int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ static constexpr bool is_src0_type_f16 = src0_type == GGML_TYPE_F16;
+
+ // bf16 and fp32 inputs already use an fp32 compute type (enforced by the static_assert below),
+ // so a forced fp32 compute type only needs to be checked when src0 is fp16.
+ static_assert(is_src0_type_f16 || traits::compute_type == CUBLAS_COMPUTE_32F);
+
+ const bool need_compute_32f = is_src0_type_f16 && !force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
+ || GGML_CUDA_CC_IS_RDNA4(cc)
+ || cc == GGML_CUDA_CC_VOLTA
+ || force_compute_type.fp32);
+
+ if (dst->op_params[0] == GGML_PREC_DEFAULT && !need_compute_32f) {
if constexpr (src0_type == GGML_TYPE_F32) {
dst_t = (char *) dst_ddf; // Direct F32 output
} else {
}
} else {
dst_t = (char *) dst_ddf;
- cu_compute_type = CUBLAS_COMPUTE_32F;
- cu_data_type = CUDA_R_32F;
- alpha = &alpha_f32;
- beta = &beta_f32;
- }
-
- int id = ggml_cuda_get_device();
- const int cc = ggml_cuda_info().devices[id].cc;
- if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
- cu_compute_type = CUBLAS_COMPUTE_32F;
- alpha = &alpha_f32;
- beta = &beta_f32;
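+ // Compute and store the result in fp32, reusing the fp32 traits for the compute/data types and scaling factors.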
+ cu_compute_type = batched_mul_mat_traits<GGML_TYPE_F32>::compute_type;
+ cu_data_type = batched_mul_mat_traits<GGML_TYPE_F32>::data_type;
+ alpha = batched_mul_mat_traits<GGML_TYPE_F32>::get_alpha();
+ beta = batched_mul_mat_traits<GGML_TYPE_F32>::get_beta();
}
GGML_ASSERT(ne12 % ne02 == 0);