option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic)
option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF)
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
+ option(WHISPER_HIPBLAS "whisper: support for hipBLAS" OFF)
option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
endif()
endif()
endif()
+
+# Optional ROCm back end, enabled with -DWHISPER_HIPBLAS=ON.
+# Builds ggml-cuda.cu as a C++ object library (ggml-rocm) linked against the
+# HIP runtime, hipBLAS and rocBLAS, and appends it to WHISPER_EXTRA_LIBS.
+# If any of the three packages is missing, a warning is printed and the
+# back end is left out.
+if (WHISPER_HIPBLAS)
+    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
+
+    # Pass the variable *names* to if(): an unquoted ${...} expansion of an
+    # empty or undefined variable leaves a malformed condition and aborts the
+    # configure step.
+    if (NOT CMAKE_C_COMPILER_ID MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
+    endif()
+    if (NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
+    endif()
+
+    find_package(hip)
+    find_package(hipblas)
+    find_package(rocblas)
+
+    # roc::rocblas is linked below, so rocblas has to be found as well.
+    if (hip_FOUND AND hipblas_FOUND AND rocblas_FOUND)
+        message(STATUS "HIP and hipBLAS found")
+
+        # Reject the unsupported configuration before any target is defined.
+        if (WHISPER_STATIC)
+            message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
+        endif()
+
+        add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+
+        add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
+        set_property(TARGET ggml-rocm PROPERTY POSITION_INDEPENDENT_CODE ON)
+        # The .cu file is compiled by the C++ compiler rather than as CUDA.
+        set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
+        target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
+
+        list(APPEND WHISPER_EXTRA_LIBS ggml-rocm)
+    else()
+        message(WARNING "HIP, hipBLAS or rocBLAS not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
+    endif()
+endif()
+
if (WHISPER_CLBLAST)
find_package(CLBlast)
if (CLBlast_FOUND)
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif
+# ROCm/hipBLAS back end, enabled by defining WHISPER_HIPBLAS.
+# ROCM_PATH, HIPCC and GPU_TARGETS are assigned with ?= and therefore keep any
+# value they already have.
+ifdef WHISPER_HIPBLAS
+  # Toolchain location and target GPU architectures. GPU_TARGETS defaults to
+  # whatever the amdgpu-arch tool under ROCM_PATH prints.
+  ROCM_PATH   ?= /opt/rocm
+  HIPCC       ?= $(ROCM_PATH)/bin/hipcc
+  GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
+
+  # One --offload-arch=<target> flag per entry of GPU_TARGETS.
+  HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
+
+  # Both macros are defined for C and C++ sources. GGML_USE_CUBLAS is
+  # presumably kept so the existing cuBLAS-guarded code stays enabled;
+  # confirm against the sources.
+  CFLAGS   += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
+  CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
+
+  # Link against the ROCm libraries and record their directory as an rpath.
+  LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib -lhipblas -lamdhip64 -lrocblas
+
+  WHISPER_OBJ += ggml-cuda.o
+
+# ggml-cuda.cu is built with hipcc; -x hip sets the input language.
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
+	$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c $< -o $@
+endif
+
ifdef WHISPER_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST
CXXFLAGS += -DGGML_USE_CLBLAST
#include <atomic>
#include <assert.h>
+// Back-end selection. With GGML_USE_HIPBLAS defined, the HIP/hipBLAS/rocBLAS
+// headers are included and the CUDA/cuBLAS identifiers listed below are
+// redefined as their HIP counterparts; otherwise the CUDA headers are used.
+#if defined(GGML_USE_HIPBLAS)
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+#include <rocblas/rocblas.h>
+
+// cuBLAS types and constants
+#define cublasHandle_t hipblasHandle_t
+#define cublasStatus_t hipblasStatus_t
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+
+// cuBLAS functions
+#define cublasCreate hipblasCreate
+#define cublasGetStatusString rocblas_status_to_string
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+// These two expand to a success status and discard their arguments.
+#define cublasLoggerConfigure(logIsOn, logToStdOut, logToStdErr, logFileName) CUBLAS_STATUS_SUCCESS
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+
+// Warp shuffle: the mask argument is dropped.
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+
+// CUDA runtime types
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaError_t hipError_t
+#define cudaEvent_t hipEvent_t
+#define cudaMemcpyKind hipMemcpyKind
+#define cudaStream_t hipStream_t
+
+// CUDA runtime constants
+#define cudaSuccess hipSuccess
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+
+// Device management
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaSetDevice hipSetDevice
+#define cudaDeviceSynchronize hipDeviceSynchronize
+
+// Error reporting
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+
+// Events and streams
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDestroy hipEventDestroy
+#define cudaEventRecord hipEventRecord
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+// Two-argument form only; 0 is passed as the third argument of hipStreamWaitEvent.
+#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
+
+// Memory allocation and transfer
+#define cudaMalloc hipMalloc
+#define cudaMallocHost hipHostMalloc
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaMemset hipMemset
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#else
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
+#endif
#include "ggml-cuda.h"
#include "ggml.h"