git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : initial hipBLAS support (#1209)
author ardfork <redacted>
Sun, 27 Aug 2023 17:03:58 +0000 (17:03 +0000)
committer GitHub <redacted>
Sun, 27 Aug 2023 17:03:58 +0000 (20:03 +0300)
CMakeLists.txt
Makefile
ggml-cuda.cu

index 91385cb3f81e6e2f6a4e5cb8d5953cd962d1f50d..407d9800ca7dd00643b46a94483e2fd10d2b9d91 100644 (file)
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -65,6 +65,7 @@ else()
     option(WHISPER_BLAS_VENDOR           "whisper: BLAS library vendor" Generic)
     option(WHISPER_OPENBLAS              "whisper: prefer OpenBLAS"     OFF)
     option(WHISPER_CUBLAS                "whisper: support for cuBLAS"  OFF)
+    option(WHISPER_HIPBLAS               "whisper: support for hipBLAS" OFF)
     option(WHISPER_CLBLAST               "whisper: use CLBlast"         OFF)
 endif()
 
@@ -191,6 +192,46 @@ if (WHISPER_CUBLAS)
     endif()
 endif()
 
+
+# hipBLAS backend (AMD ROCm): builds ggml-cuda.cu as an object library and
+# appends it to WHISPER_EXTRA_LIBS when HIP, hipBLAS and rocBLAS are found.
+if (WHISPER_HIPBLAS)
+    # Add the default ROCm install prefix to the find_package() search path.
+    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
+
+    # Pass the variable names to if() instead of expanding them with ${...}:
+    # an empty compiler ID would otherwise leave the condition malformed.
+    if (NOT CMAKE_C_COMPILER_ID MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
+    endif()
+    if (NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
+    endif()
+
+    find_package(hip)
+    find_package(hipblas)
+    find_package(rocblas)
+
+    # rocblas_FOUND is checked as well because roc::rocblas is linked below.
+    if (hipblas_FOUND AND hip_FOUND AND rocblas_FOUND)
+        message(STATUS "HIP and hipBLAS found")
+        # Directory-scoped definitions; GGML_USE_CUBLAS is set together with GGML_USE_HIPBLAS.
+        add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+        add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
+        set_property(TARGET ggml-rocm PROPERTY POSITION_INDEPENDENT_CODE ON)
+        # Compile the .cu source as C++.
+        set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
+        target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
+
+        if (WHISPER_STATIC)
+            message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
+        endif()
+        set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm)
+    else()
+        message(WARNING "hipBLAS, rocBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
+    endif()
+endif()
+
 if (WHISPER_CLBLAST)
     find_package(CLBlast)
     if (CLBlast_FOUND)
index 49530031ddb76ea232139bfe9a7a1adf1b031d75..ee8017b0b16163b29bfa592855856ebdb00f8975 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -161,6 +161,26 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
        $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
 endif
 
+# hipBLAS (ROCm) build: enabled when WHISPER_HIPBLAS is defined.
+ifdef WHISPER_HIPBLAS
+# Toolchain locations; ?= keeps any value that is already set.
+       ROCM_PATH   ?= /opt/rocm
+       HIPCC       ?= $(ROCM_PATH)/bin/hipcc
+# One --offload-arch flag per entry printed by amdgpu-arch.
+       GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
+       HIPFLAGS    += $(addprefix --offload-arch=,$(GPU_TARGETS))
+# The same preprocessor defines go to both the C and C++ flags.
+       CXXFLAGS    += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
+       CFLAGS      += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
+# Library search path, rpath and the ROCm libraries to link.
+       LDFLAGS     += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib -lhipblas -lamdhip64 -lrocblas
+       WHISPER_OBJ += ggml-cuda.o
+
+# ggml-cuda.cu is compiled by hipcc with -x hip.
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
+       $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
+endif
+
 ifdef WHISPER_CLBLAST
        CFLAGS          += -DGGML_USE_CLBLAST
        CXXFLAGS        += -DGGML_USE_CLBLAST
index 50df20edd7a7b211324e0af1c18e970bc250251d..694e0bcb83704153ff91ab2667b8b54902d03365 100644 (file)
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6,9 +6,82 @@
 #include <atomic>
 #include <assert.h>
 
+#if defined(GGML_USE_HIPBLAS)
+// HIP build: include the ROCm headers and alias the CUDA/cuBLAS names to
+// HIP/hipBLAS/rocBLAS names (or to constant stubs) via the macros below.
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+#include <rocblas/rocblas.h>
+
+// cuBLAS constants and types
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define cublasHandle_t hipblasHandle_t
+#define cublasStatus_t hipblasStatus_t
+
+// cuBLAS functions
+#define cublasCreate hipblasCreate
+#define cublasGetStatusString rocblas_status_to_string
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+
+// cuBLAS calls replaced by a constant success status (arguments are discarded)
+#define cublasLoggerConfigure(logIsOn, logToStdOut, logToStdErr, logFileName) CUBLAS_STATUS_SUCCESS
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+
+// warp shuffle: the mask argument is dropped
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+
+// error handling
+#define cudaError_t hipError_t
+#define cudaSuccess hipSuccess
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+
+// device management
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaSetDevice hipSetDevice
+
+// memory allocation
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaMalloc hipMalloc
+#define cudaMallocHost hipHostMalloc
+#define cudaMemset hipMemset
+
+// memory copies
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemcpyKind hipMemcpyKind
+
+// streams: the two-argument cudaStreamWaitEvent form passes 0 as the third argument
+#define cudaStream_t hipStream_t
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
+
+// events
+#define cudaEvent_t hipEvent_t
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDestroy hipEventDestroy
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#else
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#endif
 
 #include "ggml-cuda.h"
 #include "ggml.h"