]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
HIP : Add HIP 7.0+ compatibility for hipBLAS compute types (#14634)
authorSlobodan Josic <redacted>
Fri, 11 Jul 2025 16:55:00 +0000 (18:55 +0200)
committerGitHub <redacted>
Fri, 11 Jul 2025 16:55:00 +0000 (18:55 +0200)
ggml/src/ggml-cuda/vendors/hip.h

index 1a28831b7a96b74f880363a05450819cb2d66a5a..184d445f5c067a617c01408a0ced35d54afd1442 100644 (file)
@@ -10,9 +10,6 @@
 #include "rocblas/rocblas.h"
 #endif // __HIP_PLATFORM_AMD__
 
-#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
-#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
-#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_OP_N HIPBLAS_OP_N
@@ -30,7 +27,6 @@
 #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
 #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
-#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
 #define cublasCreate hipblasCreate
 #define cublasDestroy hipblasDestroy
 #define cublasGemmEx hipblasGemmEx
@@ -42,7 +38,6 @@
 #define cublasSgemm hipblasSgemm
 #define cublasStatus_t hipblasStatus_t
 #define cublasOperation_t hipblasOperation_t
-#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
 
+#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
+#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
+#define cublasComputeType_t hipblasComputeType_t
+#define cudaDataType_t hipDataType
+#else
+#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define cublasComputeType_t hipblasDatatype_t
+#define cudaDataType_t hipblasDatatype_t
+#endif
+
 #define __CUDA_ARCH__ 1300
 
 #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)