]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: fix FP16 cuBLAS GEMM (llama/11396)
authorJohannes Gäßler <redacted>
Fri, 24 Jan 2025 20:02:43 +0000 (21:02 +0100)
committerGeorgi Gerganov <redacted>
Mon, 3 Feb 2025 20:00:57 +0000 (22:00 +0200)
ggml/src/ggml-cuda/ggml-cuda.cu

index fb3d9e2d92e093bb5dd0f43b9bd2b637cc529ea6..fbe889a01221b01d3172ae93de34f0cf0ca5703f 100644 (file)
@@ -1114,8 +1114,8 @@ static void ggml_cuda_op_mul_mat_cublas(
             CUBLAS_CHECK(
                 cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                         row_diff, src1_ncols, ne10,
-                        &alpha, src0_ptr,       CUDA_R_16F, ne00,
-                                    src1_ptr,       CUDA_R_16F, ne10,
+                        &alpha, src0_ptr,  CUDA_R_16F, ne00,
+                                src1_ptr,  CUDA_R_16F, ne10,
                         &beta,   dst_dd_i, CUDA_R_32F, ldc,
                         CUBLAS_COMPUTE_32F,
                         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1128,9 +1128,9 @@ static void ggml_cuda_op_mul_mat_cublas(
             CUBLAS_CHECK(
                 cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                         row_diff, src1_ncols, ne10,
-                        &alpha_f16, src0_ptr,       CUDA_R_16F, ne00,
-                                    src1_ptr,       CUDA_R_16F, ne10,
-                        &beta_f16,   dst_dd_i, CUDA_R_16F, ldc,
+                        &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
+                                    src1_ptr,      CUDA_R_16F, ne10,
+                        &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
                         CUBLAS_COMPUTE_16F,
                         CUBLAS_GEMM_DEFAULT_TENSOR_OP));