]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ROCm Port (#1087)
authorHenri Vasserman <redacted>
Fri, 25 Aug 2023 09:09:42 +0000 (12:09 +0300)
committerGitHub <redacted>
Fri, 25 Aug 2023 09:09:42 +0000 (12:09 +0300)
* use hipblas based on cublas
* Update Makefile for the Cuda kernels
* Expand arch list and make it overrideable
* Fix multi GPU on multiple amd architectures with rocblas_initialize() (#5)
* add hipBLAS to README
* new build arg LLAMA_CUDA_MMQ_Y
* fix half2 decomposition
* Add intrinsics polyfills for AMD
* AMD assembly optimized __dp4a
* Allow overriding CC_TURING
* use "ROCm" instead of "CUDA"
* ignore all build dirs
* Add Dockerfiles
* fix llama-bench
* fix -nommq help for non CUDA/HIP

---------

Co-authored-by: YellowRoseCx <redacted>
Co-authored-by: ardfork <redacted>
Co-authored-by: funnbot <redacted>
Co-authored-by: Engininja2 <redacted>
Co-authored-by: Kerfuffle <redacted>
Co-authored-by: jammm <redacted>
Co-authored-by: jdecourval <redacted>
12 files changed:
.devops/full-rocm.Dockerfile [new file with mode: 0644]
.devops/main-rocm.Dockerfile [new file with mode: 0644]
.dockerignore
.gitignore
CMakeLists.txt
Makefile
README.md
common/common.cpp
examples/llama-bench/llama-bench.cpp
ggml-cuda.cu
ggml-cuda.h
llama.cpp

diff --git a/.devops/full-rocm.Dockerfile b/.devops/full-rocm.Dockerfile
new file mode 100644 (file)
index 0000000..6c521e9
--- /dev/null
@@ -0,0 +1,44 @@
+ARG UBUNTU_VERSION=22.04
+
+# This needs to generally match the container host's environment.
+ARG ROCM_VERSION=5.6
+
+# Target the CUDA build image
+ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
+
+FROM ${BASE_ROCM_DEV_CONTAINER} as build
+
+# Unless otherwise specified, we make a fat build.
+# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878
+# This is mostly tied to rocBLAS supported archs.
+ARG ROCM_DOCKER_ARCH=\
+    gfx803 \
+    gfx900 \
+    gfx906 \
+    gfx908 \
+    gfx90a \
+    gfx1010 \
+    gfx1030 \
+    gfx1100 \
+    gfx1101 \
+    gfx1102
+
+COPY requirements.txt requirements.txt
+
+RUN pip install --upgrade pip setuptools wheel \
+    && pip install -r requirements.txt
+
+WORKDIR /app
+
+COPY . .
+
+# Set nvcc architecture
+ENV GPU_TARGETS=${ROCM_DOCKER_ARCH}
+# Enable ROCm
+ENV LLAMA_HIPBLAS=1
+ENV CC=/opt/rocm/llvm/bin/clang
+ENV CXX=/opt/rocm/llvm/bin/clang++
+
+RUN make
+
+ENTRYPOINT ["/app/.devops/tools.sh"]
diff --git a/.devops/main-rocm.Dockerfile b/.devops/main-rocm.Dockerfile
new file mode 100644 (file)
index 0000000..789deff
--- /dev/null
@@ -0,0 +1,44 @@
+ARG UBUNTU_VERSION=22.04
+
+# This needs to generally match the container host's environment.
+ARG ROCM_VERSION=5.6
+
+# Target the CUDA build image
+ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
+
+FROM ${BASE_ROCM_DEV_CONTAINER} as build
+
+# Unless otherwise specified, we make a fat build.
+# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878
+# This is mostly tied to rocBLAS supported archs.
+ARG ROCM_DOCKER_ARCH=\
+    gfx803 \
+    gfx900 \
+    gfx906 \
+    gfx908 \
+    gfx90a \
+    gfx1010 \
+    gfx1030 \
+    gfx1100 \
+    gfx1101 \
+    gfx1102
+
+COPY requirements.txt requirements.txt
+
+RUN pip install --upgrade pip setuptools wheel \
+    && pip install -r requirements.txt
+
+WORKDIR /app
+
+COPY . .
+
+# Set nvcc architecture
+ENV GPU_TARGETS=${ROCM_DOCKER_ARCH}
+# Enable ROCm
+ENV LLAMA_HIPBLAS=1
+ENV CC=/opt/rocm/llvm/bin/clang
+ENV CXX=/opt/rocm/llvm/bin/clang++
+
+RUN make
+
+ENTRYPOINT [ "/app/main" ]
index 462fac23a69321cd7ba478493dbaeac226155b19..c6ef6c86c9fe141f4c485c40814f763014abb989 100644 (file)
@@ -5,14 +5,7 @@
 .vscode/
 .DS_Store
 
-build/
-build-em/
-build-debug/
-build-release/
-build-static/
-build-no-accel/
-build-sanitize-addr/
-build-sanitize-thread/
+build*/
 
 models/*
 
index 6cb7d9bc64d11c6a684e763f93e9318b59d26d57..e5faab774bed760cb159480324fb769f9875924b 100644 (file)
 .vs/
 .vscode/
 
-build/
-build-em/
-build-debug/
-build-release/
-build-ci-debug/
-build-ci-release/
-build-static/
-build-cublas/
-build-opencl/
-build-metal/
-build-mpi/
-build-no-accel/
-build-sanitize-addr/
-build-sanitize-thread/
+build*/
 out/
 tmp/
 
index bb63ef98e30134aecc5f40cda5a30eefbd5279e2..ba008bcc66da50bc892af1897f80442eefd90c8b 100644 (file)
@@ -74,6 +74,7 @@ set(LLAMA_CUDA_DMMV_X      "32" CACHE STRING "llama: x stride for dmmv CUDA kern
 set(LLAMA_CUDA_MMV_Y        "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
 option(LLAMA_CUDA_F16                        "llama: use 16 bit floats for some calculations"   OFF)
 set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
+option(LLAMA_HIPBLAS                         "llama: use hipBLAS"                               OFF)
 option(LLAMA_CLBLAST                         "llama: use CLBlast"                               OFF)
 option(LLAMA_METAL                           "llama: use Metal"                                 OFF)
 option(LLAMA_MPI                             "llama: use MPI"                                   OFF)
@@ -352,6 +353,43 @@ if (LLAMA_CLBLAST)
     endif()
 endif()
 
+if (LLAMA_HIPBLAS)
+    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
+
+    if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
+    endif()
+    if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
+    endif()
+
+    find_package(hip)
+    find_package(hipblas)
+    find_package(rocblas)
+
+    if (${hipblas_FOUND} AND ${hip_FOUND})
+        message(STATUS "HIP and hipBLAS found")
+        add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+        add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
+        if (LLAMA_CUDA_FORCE_DMMV)
+            target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV)
+        endif()
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
+        target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
+        target_compile_definitions(ggml-rocm PRIVATE CC_TURING=1000000000)
+        set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
+        target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
+
+        if (LLAMA_STATIC)
+            message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
+        endif()
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm)
+    else()
+        message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
+    endif()
+endif()
+
 if (LLAMA_ALL_WARNINGS)
     if (NOT MSVC)
         set(c_flags
index d31acc450b26154f4ac26a68863b0e6503168407..a3400e491a0c17baf10505d589e0429f7ac24ee0 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -280,6 +280,30 @@ ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h
        $(CXX) $(CXXFLAGS) -c $< -o $@
 endif # LLAMA_CLBLAST
 
+ifdef LLAMA_HIPBLAS
+       ROCM_PATH       ?= /opt/rocm
+       HIPCC       ?= $(ROCM_PATH)/bin/hipcc
+       GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
+       LLAMA_CUDA_DMMV_X       ?= 32
+       LLAMA_CUDA_MMV_Y        ?= 1
+       LLAMA_CUDA_KQUANTS_ITER ?= 2
+       CFLAGS      += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
+       CXXFLAGS    += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
+       LDFLAGS     += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
+       LDFLAGS         += -lhipblas -lamdhip64 -lrocblas
+       HIPFLAGS    += $(addprefix --offload-arch=,$(GPU_TARGETS))
+       HIPFLAGS    += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
+       HIPFLAGS    += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
+       HIPFLAGS    += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
+       HIPFLAGS    += -DCC_TURING=1000000000
+ifdef LLAMA_CUDA_FORCE_DMMV
+       HIPFLAGS        += -DGGML_CUDA_FORCE_DMMV
+endif # LLAMA_CUDA_FORCE_DMMV
+       OBJS        += ggml-cuda.o
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
+       $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
+endif # LLAMA_HIPBLAS
+
 ifdef LLAMA_METAL
        CFLAGS   += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
        CXXFLAGS += -DGGML_USE_METAL
index eebb113929934fc834b448fd26f4f6ceef082dbe..95471fdbb145ba868bb2b86137005aaf6c2bc435 100644 (file)
--- a/README.md
+++ b/README.md
@@ -422,6 +422,35 @@ Building the program with BLAS support may lead to some performance improvements
   | LLAMA_CUDA_F16          | Boolean                |   false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
   | LLAMA_CUDA_KQUANTS_ITER | 1 or 2                 |       2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
 
+- #### hipBLAS
+
+  This provide BLAS acceleation on HIP supported GPU like AMD GPU.
+  Make sure to have ROCm installed.
+  You can download it from your Linux distro's package manager or from here: [ROCm Quick Start (Linux)](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html).
+  Windows support is coming soon...
+
+  - Using `make`:
+    ```bash
+    make LLAMA_HIPBLAS=1
+    ```
+  - Using `CMake`:
+    ```bash
+    mkdir build
+    cd build
+    CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DLLAMA_HIPBLAS=ON
+    cmake --build .
+    ```
+
+  The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
+  If your GPU is not officialy supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 or 11.0.0 on RDNA3.
+  The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above):
+
+  | Option                  | Legal values           | Default | Description |
+  |-------------------------|------------------------|---------|-------------|
+  | LLAMA_CUDA_DMMV_X       | Positive integer >= 32 |      32 | Number of values in x direction processed by the HIP dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
+  | LLAMA_CUDA_MMV_Y        | Positive integer       |       1 | Block size in y direction for the HIP mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
+  | LLAMA_CUDA_KQUANTS_ITER | 1 or 2                 |       2 | Number of values processed per iteration and per HIP thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
+
 - #### CLBlast
 
   OpenCL acceleration is provided by the matrix multiplication kernels from the [CLBlast](https://github.com/CNugteren/CLBlast) project and custom kernels for ggml that can generate tokens on the GPU.
index 53002ba306b572ca9d11609f3aceb063f51575a1..ff19ec4e50f60485b66ec940b9d50937c3e105fe 100644 (file)
@@ -613,9 +613,11 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
     fprintf(stdout, "  -mg i, --main-gpu i   the GPU to use for scratch and small tensors\n");
     fprintf(stdout, "  -lv, --low-vram       don't allocate VRAM scratch buffer\n");
+#ifdef GGML_USE_CUBLAS
     fprintf(stdout, "  -nommq, --no-mul-mat-q\n");
-    fprintf(stdout, "                        use cuBLAS instead of custom mul_mat_q CUDA kernels.\n");
+    fprintf(stdout, "                        use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
     fprintf(stdout, "                        Not recommended since this is both slower and uses more VRAM.\n");
+#endif // GGML_USE_CUBLAS
 #endif
     fprintf(stdout, "  --mtest               compute maximum memory usage\n");
     fprintf(stdout, "  --export              export the computation graph to 'llama.ggml'\n");
index 36057bfca56056d466659b580a1490a7af1b2ac4..7a28115841fc3f0fb179d15486e9ace85b084391 100755 (executable)
@@ -18,9 +18,7 @@
 #include "llama.h"
 #include "common.h"
 #include "build-info.h"
-#ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
-#endif
 
 // utils
 static uint64_t get_time_ns() {
@@ -504,7 +502,7 @@ struct test {
 
     static std::string get_backend() {
         if (cuda) {
-            return "CUDA";
+            return GGML_CUDA_NAME;
         }
         if (opencl) {
             return "OpenCL";
index 3bd1caf23f86d674e6e55675164733774583be33..83d53c13c1a54a0ae48a16fef63a838e502e39cf 100644 (file)
 #include <atomic>
 #include <assert.h>
 
+#if defined(GGML_USE_HIPBLAS)
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+#ifdef __HIP_PLATFORM_AMD__
+// for rocblas_initialize()
+#include "rocblas/rocblas.h"
+#endif
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F  HIPBLAS_R_16F
+#define CUDA_R_32F  HIPBLAS_R_32F
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define cublasCreate hipblasCreate
+#define cublasGemmEx hipblasGemmEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#define cudaEvent_t hipEvent_t
+#define cudaEventDestroy hipEventDestroy
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaMalloc hipMalloc
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemcpyKind hipMemcpyKind
+#define cudaMemset hipMemset
+#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
+#define cudaSetDevice hipSetDevice
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
+#else
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#endif
 
 #include "ggml-cuda.h"
 #include "ggml.h"
 
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#ifndef CC_TURING
 #define CC_TURING   700
+#endif
+
+#if defined(GGML_USE_HIPBLAS)
+#define __CUDA_ARCH__ 1300
+
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+    return reinterpret_cast<const int&>(c);
+}
+
+static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
+    c = __builtin_amdgcn_sdot4(a, b, c, false);
+#elif defined(__gfx1100__)
+    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
+#elif defined(__gfx1010__) || defined(__gfx900__)
+    int tmp1;
+    int tmp2;
+    asm("\n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        "
+        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
+        : "v"(a), "v"(b)
+    );
+#else
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+    return c;
+}
+#endif
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -424,8 +525,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
 static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const dfloat d = x[ib].dm.x;
-    const dfloat m = x[ib].dm.y;
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
 
     const int vui = x[ib].qs[iqs];
 
@@ -467,8 +568,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
 static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
-    const dfloat d = x[ib].dm.x;
-    const dfloat m = x[ib].dm.y;
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -520,8 +621,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const uint8_t q = x[i].qs[32*n + l];
     float * y = yy + i*QK_K + 128*n;
 
-    float dall = x[i].dm.x;
-    float dmin = x[i].dm.y;
+    float dall = __low2half(x[i].dm);
+    float dmin = __high2half(x[i].dm);
     y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
@@ -531,8 +632,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const int il = tid%16;  // 0...15
     const uint8_t q = x[i].qs[il] >> (2*is);
     float * y = yy + i*QK_K + 16*is + il;
-    float dall = x[i].dm.x;
-    float dmin = x[i].dm.y;
+    float dall = __low2half(x[i].dm);
+    float dmin = __high2half(x[i].dm);
     y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
     y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
 #endif
@@ -618,8 +719,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
 
     float * y = yy + i*QK_K + 64*il + n*ir;
 
-    const float dall = x[i].dm.x;
-    const float dmin = x[i].dm.y;
+    const float dall = __low2half(x[i].dm);
+    const float dmin = __high2half(x[i].dm);
 
     const uint8_t * q = x[i].qs + 32*il + n*ir;
 
@@ -657,8 +758,8 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
 
     float * y = yy + i*QK_K + 64*il + 2*ir;
 
-    const float dall = x[i].dm.x;
-    const float dmin = x[i].dm.y;
+    const float dall = __low2half(x[i].dm);
+    const float dmin = __high2half(x[i].dm);
 
     const uint8_t * ql = x[i].qs + 32*il + 2*ir;
     const uint8_t * qh = x[i].qh + 2*ir;
@@ -770,8 +871,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
         const float   * y = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
 
-        const float dall = x[i].dm.x;
-        const float dmin = x[i].dm.y;
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
 
         const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
         aux[0] = a[0] & 0x0f0f0f0f;
@@ -991,8 +1092,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
         const float   * y1 = yy + i*QK_K + y_offset;
         const float   * y2 = y1 + 128;
 
-        const float dall = x[i].dm.x;
-        const float dmin = x[i].dm.y;
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
 
         const uint16_t * a = (const uint16_t *)x[i].scales;
         aux[0] = a[im+0] & kmask1;
@@ -1124,8 +1225,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
         const float   * y1  = yy + i*QK_K + y_offset;
         const float   * y2  = y1 + 128;
 
-        const float dall = x[i].dm.x;
-        const float dmin = x[i].dm.y;
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
 
         const uint16_t * a = (const uint16_t *)x[i].scales;
         aux[0] = a[im+0] & kmask1;
@@ -1348,8 +1449,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
         return;
     }
 
-    y[ib].ds.x = d;
-    y[ib].ds.y = sum;
+    reinterpret_cast<half&>(y[ib].ds.x) = d;
+    reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -2346,7 +2447,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
         u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
     }
 
-    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x);
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2432,7 +2533,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR2_K; ++ i) {
         u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + i].ds.x;
+        d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
     }
 
     return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
@@ -2551,7 +2652,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR3_K; ++i) {
         u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + i].ds.x;
+        d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
     }
 
     return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
@@ -2720,7 +2821,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 
     for (int i = 0; i < QR4_K; ++i) {
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        d8[i] = bq8i->ds.x;
+        d8[i] = __low2half(bq8i->ds);
 
         const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
         u[2*i+0] = q8[0];
@@ -2747,8 +2848,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     const float dall = bq4_K->d[0];
     const float dmin = bq4_K->d[1];
 
-    const float d8_1 = bq8_1[0].ds.x;
-    const float d8_2 = bq8_1[1].ds.x;
+    const float d8_1 = __low2float(bq8_1[0].ds);
+    const float d8_2 = __low2float(bq8_1[1].ds);
 
     const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
     const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
@@ -2901,7 +3002,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR5_K; ++i) {
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        d8[i] = bq8i->ds.x;
+        d8[i] = __low2float(bq8i->ds);
 
         const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
         u[2*i+0] = q8[0];
@@ -2919,8 +3020,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 
     const float d = bq5_K->d;
 
-    const float d8_1 = bq8_1[0].ds.x;
-    const float d8_2 = bq8_1[1].ds.x;
+    const float d8_1 = __low2half(bq8_1[0].ds);
+    const float d8_2 = __low2half(bq8_1[1].ds);
 
     const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
     const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
@@ -3075,7 +3176,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR6_K; ++i) {
         u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + 2*i].ds.x;
+        d8[i] = __low2half(bq8_1[bq8_offset + 2*i].ds);
     }
 
     return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
@@ -3243,7 +3344,7 @@ static __device__ __forceinline__ void mul_mat_q(
                     *dsi_dst = *dsi_src;
                 } else {
                     float * dfi_dst = (float *) dsi_dst;
-                    *dfi_dst = (*dsi_src).x;
+                    *dfi_dst = __low2half(*dsi_src);
                 }
             }
 
@@ -4944,10 +5045,18 @@ void ggml_init_cublas() {
     static bool initialized = false;
 
     if (!initialized) {
+
+#ifdef __HIP_PLATFORM_AMD__
+        // Workaround for a rocBLAS bug when using multiple graphics cards:
+        // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
+        rocblas_initialize();
+        CUDA_CHECK(cudaDeviceSynchronize());
+#endif
+
         CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
-        fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count);
+        fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
         for (int id = 0; id < g_device_count; ++id) {
             cudaDeviceProp prop;
             CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
index f66bb16786af914867c5a49cd107f9ff85267e4d..a72e82069b9f1430fdbcc007f93ae26ea6c2f924 100644 (file)
@@ -2,6 +2,14 @@
 
 #include "ggml.h"
 
+#ifdef GGML_USE_HIPBLAS
+#define GGML_CUDA_NAME "ROCm"
+#define GGML_CUBLAS_NAME "hipBLAS"
+#else
+#define GGML_CUDA_NAME "CUDA"
+#define GGML_CUBLAS_NAME "cuBLAS"
+#endif
+
 #ifdef  __cplusplus
 extern "C" {
 #endif
index 52ba31d79bc1ea2686010a2bf88c433d1d72237e..d12b6d1cb0713f5a8836f6d77373d50532569e07 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1836,7 +1836,7 @@ static void llm_load_tensors(
     (void) main_gpu;
     (void) mul_mat_q;
 #if defined(GGML_USE_CUBLAS)
-    LLAMA_LOG_INFO("%s: using CUDA for GPU acceleration\n", __func__);
+    LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
     ggml_cuda_set_main_device(main_gpu);
     ggml_cuda_set_mul_mat_q(mul_mat_q);
 #define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_GPU