} \
} while (0)
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
+typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+
+// QK = number of values after dequantization
+// QR = QK / number of values before dequantization
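+// e.g. q4_0: a block dequantizes to QK4_0 = 32 floats but stores them as 16
+// bytes of nibbles, so QR4_0 = 32/16 = 2 quantized values per stored byte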
#define QK4_0 32
+#define QR4_0 2
typedef struct {
    float   d;              // delta
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
#define QK4_1 32
+#define QR4_1 2
typedef struct {
    float   d;              // delta
    float   m;              // min
    uint8_t qs[QK4_1 / 2];  // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
#define QK5_0 32
+#define QR5_0 2
typedef struct {
    half    d;              // delta
    uint8_t qh[4];          // 5-th bit of quants
    uint8_t qs[QK5_0 / 2];  // nibbles / quants
} block_q5_0;
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
#define QK5_1 32
+#define QR5_1 2
typedef struct {
    half    d;              // delta
    half    m;              // min
    uint8_t qh[4];          // 5-th bit of quants
    uint8_t qs[QK5_1 / 2];  // nibbles / quants
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
#define QK8_0 32
+#define QR8_0 1
typedef struct {
    float  d;           // delta
    int8_t qs[QK8_0];   // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
+
+static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const block_q4_0 * x = (const block_q4_0 *) vx;
+
+ const float d = x[ib].d;
+
+ const uint8_t vui = x[ib].qs[iqs];
+
+ const int8_t vi0 = vui & 0xF;
+ const int8_t vi1 = vui >> 4;
+
+ v0 = (vi0 - 8)*d;
+ v1 = (vi1 - 8)*d;
+}
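+// illustrative only (not part of this change; the helper name is hypothetical):
+// a host-side reference that dequantizes one q4_0 block the same way, handy for
+// sanity-checking the kernel. low nibbles fill the first half of the output and
+// high nibbles the second half, which is exactly what the y_offset = qk/2
+// indexing in dequantize_mul_mat_vec below relies on:
+//
+//   static void dequantize_block_q4_0_ref(const block_q4_0 * b, float * y) {
+//       for (int j = 0; j < QK4_0/2; ++j) {
+//           y[j + 0      ] = ((b->qs[j] & 0xF) - 8) * b->d; // low nibble
+//           y[j + QK4_0/2] = ((b->qs[j] >>  4) - 8) * b->d; // high nibble
+//       }
+//   }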
+
+static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const block_q4_1 * x = (const block_q4_1 *) vx;
+
+ const float d = x[ib].d;
+ const float m = x[ib].m;
+
+ const uint8_t vui = x[ib].qs[iqs];
+
+ const int8_t vi0 = vui & 0xF;
+ const int8_t vi1 = vui >> 4;
+
+ v0 = vi0*d + m;
+ v1 = vi1*d + m;
+}
+
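+// q5_0 keeps the 5th bit of all 32 quants in the separate qh word: bit iqs
+// extends the low nibble of qs[iqs] (quant iqs) and bit iqs + 16 extends the
+// high nibble (quant iqs + QK5_0/2); e.g. for iqs = 3, xh_0 picks bit 3 via
+// ((qh >> 3) << 4) & 0x10 and xh_1 picks bit 19 via (qh >> 15) & 0x10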
+static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const block_q5_0 * x = (const block_q5_0 *) vx;
+
+ const float d = x[ib].d;
+
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+
+ const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
+ const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
+
+ v0 = x0*d;
+ v1 = x1*d;
+}
+
+static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const block_q5_1 * x = (const block_q5_1 *) vx;
+
+ const float d = x[ib].d;
+ const float m = x[ib].m;
+
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+
+ const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
+ const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
+
+ v0 = x0*d + m;
+ v1 = x1*d + m;
+}
+
+static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const block_q8_0 * x = (const block_q8_0 *) vx;
+
+ const float d = x[ib].d;
+
+ const int8_t vi0 = x[ib].qs[iqs + 0];
+ const int8_t vi1 = x[ib].qs[iqs + 1];
+
+ v0 = vi0*d;
+ v1 = vi1*d;
+}
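+// q8_0 stores one byte per value (QR8_0 = 1), so iqs indexes quants directly
+// and the kernel reads the two results from adjacent positions (y_offset = 1)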
+
+static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const half * x = (const half *) vx;
+
+ v0 = __half2float(x[ib + 0]);
+ v1 = __half2float(x[ib + 1]);
+}
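+// f16 needs no unpacking; the launch site below therefore passes qk = 1 and
+// qr = 1, so ib indexes half values directly and v0/v1 are consecutive elements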
+
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
static const int qk = QK4_0;
}
}
+template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
+ const int row = blockIdx.x;
+ const int tid = threadIdx.x;
+
+    const int y_offset = qr == 1 ? 1 : qk/2; // second value is adjacent (qr == 1) or half a block away (nibble formats)
+
+ __shared__ float tmp[block_size]; // separate sum for each thread
+ tmp[tid] = 0;
+
+ for (int i = 0; i < ncols/block_size; i += 2) {
+ const int col = i*block_size + 2*tid;
+ const int ib = (row*ncols + col)/qk; // block index
+ const int iqs = (col%qk)/qr; // quant index
+ const int iybs = col - col%qk; // y block start index
+
+ // dequantize
+ float v0, v1;
+ dequantize_kernel(vx, ib, iqs, v0, v1);
+
+ // matrix multiplication
+ tmp[tid] += v0 * y[iybs + iqs + 0];
+ tmp[tid] += v1 * y[iybs + iqs + y_offset];
+ }
+
+ // sum up partial sums and write back result
+ __syncthreads();
+ for (int s=block_size/2; s>0; s>>=1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ __syncthreads();
+ }
+ if (tid == 0) {
+ dst[row] = tmp[0];
+ }
+}
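+// worked example of the index math for q4_0 (qk = 32, qr = 2): row = 0, tid = 5,
+// i = 2 gives col = 2*32 + 2*5 = 74, ib = 74/32 = 2, iqs = (74%32)/2 = 5 and
+// iybs = 74 - 74%32 = 64, so the thread multiplies the low and high nibbles of
+// qs[5] in block 2 (row values 69 and 85) with y[64 + 5] and y[64 + 5 + 16].
+// the loop that follows is a standard shared-memory tree reduction: 16 partial
+// sums are folded into the first 16 slots, then 8, 4, 2 and finally 1.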
+
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_0;
    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
}

static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK8_0;
    dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+ dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+ dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+ dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+ dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+ dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
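+// the five wrappers above differ only in their template arguments; they give
+// ggml_get_dequantize_mul_mat_vec_cuda below a uniform
+// dequantize_mul_mat_vec_cuda_t signature to hand out per ggml_type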
+
// TODO: optimize
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
    const half * x = (const half *) vx;
    const int i = blockIdx.x;
    y[i] = __half2float(x[i]);
}
static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
    convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
}
+static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 1, 1, convert_f16> // qk = qr = 1: ib indexes f16 values directly
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
    }
}
+static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return dequantize_mul_mat_vec_q4_0_cuda;
+ case GGML_TYPE_Q4_1:
+ return dequantize_mul_mat_vec_q4_1_cuda;
+ case GGML_TYPE_Q5_0:
+ return dequantize_mul_mat_vec_q5_0_cuda;
+ case GGML_TYPE_Q5_1:
+ return dequantize_mul_mat_vec_q5_1_cuda;
+ case GGML_TYPE_Q8_0:
+ return dequantize_mul_mat_vec_q8_0_cuda;
+ case GGML_TYPE_F16:
+            return convert_mul_mat_vec_f16_cuda;
+ default:
+ return nullptr;
+ }
+}
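+// usage sketch, mirroring the call site in the mul mat routine below: fetch the
+// kernel once per mat mul, then launch it for every (i03, i02) slice:
+//
+//   dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
+//   dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);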
+
// buffer pool for cuda
-#define MAX_CUDA_BUFFERS 16
+#define MAX_CUDA_BUFFERS 256
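+// a heuristic bump from 16: the split mat mul path and resident weight tensors
+// cycle many more differently-sized buffers through the pool than before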
struct scoped_spin_lock {
std::atomic_flag& lock;
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
const ggml_type type = src0->type;
+    const bool mul_mat_vec = ne11 == 1; // src1 has a single column, i.e. this is a matrix-vector multiplication
const float alpha = 1.0f;
const float beta = 0.0f;
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
size_t x_size, y_size, d_size, q_size;
- float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+ float * d_X = nullptr;
+ if (!mul_mat_vec) {
+ d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+ }
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
+ dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
GGML_ASSERT(to_fp32_cuda != nullptr);
for (int64_t i03 = 0; i03 < ne03; i03++) {
cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
- float * c_X = d_X + i * x_ne;
float * c_Y = d_Y + i * y_ne;
float * c_D = d_D + i * d_ne;
char * c_Q = d_Q + i * q_sz;
- // copy src0 and convert to fp32 on device
- CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
- to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
- CUDA_CHECK(cudaGetLastError());
- CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+ // copy src0 to device if necessary
+ if (src0->backend == GGML_BACKEND_CPU) {
+ CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
+ } else if (src0->backend == GGML_BACKEND_CUDA) {
+ c_Q = ((char *) src0->data) + i * q_sz;
+ } else {
+ GGML_ASSERT(false);
+ }
+ if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
+ CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
- // copy src1 to device
- CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+ // copy src1 to device
+ CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
- // wait for conversion
- CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+ // wait for data
+ CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
- // compute
- CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
- CUBLAS_CHECK(
- cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
- ne01, ne11, ne10,
- &alpha, c_X, ne00,
- c_Y, ne10,
- &beta, c_D, ne01));
+ // compute
+ dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
+ CUDA_CHECK(cudaGetLastError());
+
+ } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
+ float * c_X = d_X + i * x_ne;
+
+ // convert src0 to fp32 on device
+ to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
+ CUDA_CHECK(cudaGetLastError());
+ CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+ // copy src1 to device
+ CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+
+ // wait for conversion
+ CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+ // compute
+ CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
+ CUBLAS_CHECK(
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+ ne01, ne11, ne10,
+ &alpha, c_X, ne00,
+ c_Y, ne10,
+ &beta, c_D, ne01));
+ }
// copy dst to host
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
}
CUDA_CHECK(cudaDeviceSynchronize());
- ggml_cuda_pool_free(d_X, x_size);
+ if (!mul_mat_vec) {
+ ggml_cuda_pool_free(d_X, x_size);
+ }
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
ggml_cuda_pool_free(d_Q, q_size);
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
- (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
+ ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
return true;
}
return false;
}
}
+
+void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
+ const int64_t ne0 = tensor->ne[0];
+ const int64_t ne1 = tensor->ne[1];
+ const int64_t ne2 = tensor->ne[2];
+ const int64_t ne3 = tensor->ne[3];
+
+ const ggml_type type = tensor->type;
+ const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+
+ size_t q_size;
+ char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+
+ cudaStream_t cudaStream2 = g_cudaStreams2[0];
+
+    // copy tensor to device (a single 2D slice; the weight matrices this is used for have ne2 == ne3 == 1)
+ CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
+ CUDA_CHECK(cudaDeviceSynchronize());
+
+ tensor->data = d_Q;
+ tensor->backend = GGML_BACKEND_CUDA;
+}
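+// usage sketch (hypothetical caller; the names are illustrative, not part of
+// this change): offload a weight matrix once at load time, after which the mul
+// mat path above reads it straight from VRAM via its
+// src0->backend == GGML_BACKEND_CUDA branch:
+//
+//   if (il < n_gpu_layers) {
+//       ggml_cuda_transform_tensor(layer.wq);
+//   }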