// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
+#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
}
+template<typename T>
+using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
+typedef to_t_cuda_t<float> to_fp32_cuda_t;
+typedef to_t_cuda_t<half> to_fp16_cuda_t;
+
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
-typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
v.y = x[ib + iqs + 1];
}
+static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
+ const float * x = (const float *) vx;
+
+ // automatic half -> float type cast if dfloat == float
+ v.x = x[ib + iqs + 0];
+ v.y = x[ib + iqs + 1];
+}
+
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
reinterpret_cast<half&>(y[ib].ds.y) = sum;
}
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
if (i >= k) {
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
+static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+ dequantize_block<1, 1, convert_f32><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
+static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_F32:
+ return convert_fp32_to_fp16_cuda;
+ default:
+ return nullptr;
+ }
+}
+
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
GGML_ASSERT(src1_ddf_i != nullptr);
GGML_ASSERT(dst_dd_i != nullptr);
- const float alpha = 1.0f;
- const float beta = 0.0f;
const int64_t ne00 = src0->ne[0];
const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low;
- float * src0_ddq_as_f32;
- size_t src0_as = 0;
-
- if (src0->type != GGML_TYPE_F32) {
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
- src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
- to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
- }
- const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
-
int id;
CUDA_CHECK(cudaGetDevice(&id));
// ldc == nrows of the matrix that cuBLAS writes into
int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
- CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
- CUBLAS_CHECK(
- cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
- row_diff, src1_ncols, ne10,
- &alpha, src0_ddf_i, ne00,
- src1_ddf_i, ne10,
- &beta, dst_dd_i, ldc));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous(src0) && ldc == row_diff) {
+ // convert src1 to fp16, multiply as fp16, convert dst to fp32
+ half * src1_as_f16 = nullptr;
+ size_t src1_as = 0;
+ if (src1->type != GGML_TYPE_F16) {
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+ GGML_ASSERT(to_fp16_cuda != nullptr);
+ size_t ne = src1_ncols*ne10;
+ src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
+ to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
+ }
+ const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
+
+ size_t dst_as = 0;
+ half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+
+ const half alpha_f16 = 1.0f;
+ const half beta_f16 = 0.0f;
+
+ CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
+ CUBLAS_CHECK(
+ cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+ row_diff, src1_ncols, ne10,
+ &alpha_f16, src0_dd_i, CUDA_R_16F, ne00,
+ src1_ptr, CUDA_R_16F, ne10,
+ &beta_f16, dst_f16, CUDA_R_16F, ldc,
+ CUBLAS_COMPUTE_16F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+ to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
+
+ ggml_cuda_pool_free(dst_f16, dst_as);
- if (src0_as > 0) {
- ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
+ if (src1_as != 0) {
+ ggml_cuda_pool_free(src1_as_f16, src1_as);
+ }
+ }
+ else {
+ float * src0_ddq_as_f32 = nullptr;
+ size_t src0_as = 0;
+
+ if (src0->type != GGML_TYPE_F32) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+ GGML_ASSERT(to_fp32_cuda != nullptr);
+ src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
+ to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
+ }
+ const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
+
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+
+ CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
+ CUBLAS_CHECK(
+ cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+ row_diff, src1_ncols, ne10,
+ &alpha, src0_ddf_i, ne00,
+ src1_ddf_i, ne10,
+ &beta, dst_dd_i, ldc));
+
+ if (src0_as != 0) {
+ ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
+ }
}
(void) dst;