#include <atomic>
#include <assert.h>
+#if defined(GGML_USE_HIPBLAS)
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+#ifdef __HIP_PLATFORM_AMD__
+// for rocblas_initialize()
+#include "rocblas/rocblas.h"
+#endif
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F // start of the CUDA -> HIP alias table: CUDA/cuBLAS spellings are rewritten to their HIP/hipBLAS names
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F // same expansion as CUBLAS_COMPUTE_32F: both compute types map to HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0 // placeholder value 0; the cublasSetMathMode macro below discards its mode argument
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_32F HIPBLAS_R_32F
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) // mask is dropped; the macro takes exactly four arguments, so callers must pass width explicitly
+#define cublasCreate hipblasCreate
+#define cublasGemmEx hipblasGemmEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS // no-op: expands to a success status and never evaluates handle or mode
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#define cudaEvent_t hipEvent_t
+#define cudaEventDestroy hipEventDestroy
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree // word order differs on the HIP side: hipHostFree, not hipFreeHost
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaMalloc hipMalloc
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) // function-like: appends hipHostMallocDefault as the third argument
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemcpyKind hipMemcpyKind
+#define cudaMemset hipMemset
+#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
+#define cudaSetDevice hipSetDevice
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0) // two-argument form only; the third argument is fixed to 0
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
+#else
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
+#endif
#include "ggml-cuda.h"
#include "ggml.h"
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#ifndef CC_TURING
#define CC_TURING 700
+#endif
+
+#if defined(GGML_USE_HIPBLAS)
+#define __CUDA_ARCH__ 1300 // HIP builds pin this to 1300, above MIN_CC_DP4A (610) and CC_TURING (700); NOTE(review): presumably so arch-gated code paths are enabled on AMD - confirm
+
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+// HIP-side definition of __vsubss4: views a and b as int8x4_t (four packed
+// 8-bit signed lanes), subtracts b from a lane by lane with
+// __builtin_elementwise_sub_sat, and returns the four result lanes
+// reinterpreted as a single int.
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+    // The casts only change the static type of the operands; no bits are moved.
+    const int8x4_t diff = __builtin_elementwise_sub_sat(
+        reinterpret_cast<const int8x4_t&>(a),
+        reinterpret_cast<const int8x4_t&>(b));
+    return reinterpret_cast<const int&>(diff);
+}
+
+static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { // HIP-side __dp4a: returns c plus the dot product of the four int8 lanes packed in a and b (the #else branch spells this out); the implementation is picked per GPU target
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) // these targets use the sdot4 builtin
+ c = __builtin_amdgcn_sdot4(a, b, c, false); // NOTE(review): trailing false is presumably the clamp flag - confirm against the clang AMDGPU builtin docs
+#elif defined(__gfx1100__) // gfx1100 uses the sudot4 builtin instead
+ c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); // NOTE(review): the two true flags presumably mark a and b as signed, trailing false presumably clamp - confirm
+#elif defined(__gfx1010__) || defined(__gfx900__) // these targets get a hand-written instruction sequence
+ int tmp1; // scratch register for the asm block (operand %1)
+ int tmp2; // scratch register for the asm block (operand %2)
+ asm("\n \
+ v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
+ v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
+ v_add3_u32 %0, %1, %2, %0 \n \
+ v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
+ v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
+ v_add3_u32 %0, %1, %2, %0 \n \
+ "
+ : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) // outputs: %0 = c (read and written), %1/%2 = tmp1/tmp2; each byte pair is multiplied after sext and the products are added into %0 two at a time
+ : "v"(a), "v"(b) // inputs: %3 = a, %4 = b
+ );
+#else // any other target: plain C++ fallback
+ const int8x4_t va = reinterpret_cast<const int8x4_t&>(a); // view a as four int8 lanes
+ const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b); // view b as four int8 lanes
+ c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3]; // accumulate the four lane products into c
+#endif
+ return c;
+}
+#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx;
- const dfloat d = x[ib].dm.x;
- const dfloat m = x[ib].dm.y;
+ const dfloat d = __low2half(x[ib].dm);
+ const dfloat m = __high2half(x[ib].dm);
const int vui = x[ib].qs[iqs];
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx;
- const dfloat d = x[ib].dm.x;
- const dfloat m = x[ib].dm.y;
+ const dfloat d = __low2half(x[ib].dm);
+ const dfloat m = __high2half(x[ib].dm);
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n;
- float dall = x[i].dm.x;
- float dmin = x[i].dm.y;
+ float dall = __low2half(x[i].dm);
+ float dmin = __high2half(x[i].dm);
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
const int il = tid%16; // 0...15
const uint8_t q = x[i].qs[il] >> (2*is);
float * y = yy + i*QK_K + 16*is + il;
- float dall = x[i].dm.x;
- float dmin = x[i].dm.y;
+ float dall = __low2half(x[i].dm);
+ float dmin = __high2half(x[i].dm);
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
#endif
float * y = yy + i*QK_K + 64*il + n*ir;
- const float dall = x[i].dm.x;
- const float dmin = x[i].dm.y;
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
const uint8_t * q = x[i].qs + 32*il + n*ir;
float * y = yy + i*QK_K + 64*il + 2*ir;
- const float dall = x[i].dm.x;
- const float dmin = x[i].dm.y;
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
const uint8_t * ql = x[i].qs + 32*il + 2*ir;
const uint8_t * qh = x[i].qh + 2*ir;
const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
- const float dall = x[i].dm.x;
- const float dmin = x[i].dm.y;
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
aux[0] = a[0] & 0x0f0f0f0f;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
- const float dall = x[i].dm.x;
- const float dmin = x[i].dm.y;
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
- const float dall = x[i].dm.x;
- const float dmin = x[i].dm.y;
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
return;
}
- y[ib].ds.x = d;
- y[ib].ds.y = sum;
+ reinterpret_cast<half&>(y[ib].ds.x) = d;
+ reinterpret_cast<half&>(y[ib].ds.y) = sum;
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
}
- return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x);
+ return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
#pragma unroll
for (int i = 0; i < QR2_K; ++ i) {
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
- d8[i] = bq8_1[bq8_offset + i].ds.x;
+ d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
}
return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
#pragma unroll
for (int i = 0; i < QR3_K; ++i) {
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
- d8[i] = bq8_1[bq8_offset + i].ds.x;
+ d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
}
return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
for (int i = 0; i < QR4_K; ++i) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
- d8[i] = bq8i->ds.x;
+ d8[i] = __low2half(bq8i->ds);
const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
u[2*i+0] = q8[0];
const float dall = bq4_K->d[0];
const float dmin = bq4_K->d[1];
- const float d8_1 = bq8_1[0].ds.x;
- const float d8_2 = bq8_1[1].ds.x;
+ const float d8_1 = __low2float(bq8_1[0].ds);
+ const float d8_2 = __low2float(bq8_1[1].ds);
const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
#pragma unroll
for (int i = 0; i < QR5_K; ++i) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
- d8[i] = bq8i->ds.x;
+ d8[i] = __low2float(bq8i->ds);
const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
u[2*i+0] = q8[0];
const float d = bq5_K->d;
- const float d8_1 = bq8_1[0].ds.x;
- const float d8_2 = bq8_1[1].ds.x;
+ const float d8_1 = __low2half(bq8_1[0].ds);
+ const float d8_2 = __low2half(bq8_1[1].ds);
const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
#pragma unroll
for (int i = 0; i < QR6_K; ++i) {
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
- d8[i] = bq8_1[bq8_offset + 2*i].ds.x;
+ d8[i] = __low2half(bq8_1[bq8_offset + 2*i].ds);
}
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
*dsi_dst = *dsi_src;
} else {
float * dfi_dst = (float *) dsi_dst;
- *dfi_dst = (*dsi_src).x;
+ *dfi_dst = __low2half(*dsi_src);
}
}
static bool initialized = false;
if (!initialized) {
+
+#ifdef __HIP_PLATFORM_AMD__
+ // Workaround for a rocBLAS bug when using multiple graphics cards:
+ // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
+ rocblas_initialize();
+ CUDA_CHECK(cudaDeviceSynchronize());
+#endif
+
CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
int64_t total_vram = 0;
- fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count);
+ fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
for (int id = 0; id < g_device_count; ++id) {
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));