typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
-typedef void (*ggml_cuda_op_t)(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i,
- float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main);
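+// op functions for ggml_cuda_op_mul_mat: operate on a [row_low, row_high) row slice of src0
+// and on src1_ncols columns of src1, provided as f32 (src1_ddf_i) and/or quantized q8_1 (src1_ddq_i) data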
+typedef void (*ggml_cuda_op_mul_mat_t)(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, const cudaStream_t & stream);
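+// op functions for ggml_cuda_op_flatten: operate on entire tensors, treated as flat f32 arrays on the main device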
+typedef void (*ggml_cuda_op_flatten_t)(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream);
// QK = number of values after dequantization
// QR = QK / number of values before dequantization
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
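+// number of src1 columns processed per chunk when mul_mat work is split across multiple GPUs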
+#define MUL_MAT_SRC1_COL_STRIDE 128
+
+#define MAX_STREAMS 8
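+// pool of MAX_STREAMS CUDA streams per device; stream 0 is used as the main stream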
+static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
+
struct ggml_tensor_extra_gpu {
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
- cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
+ cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
};
+// this is faster on Windows
+// probably because the Windows CUDA libraries forget to make this check before invoking the drivers
+inline cudaError_t ggml_cuda_set_device(const int device) {
+ int current_device;
+    CUDA_CHECK(cudaGetDevice(&current_device));
+
+ if (device == current_device) {
+ return cudaSuccess;
+ }
+
+ return cudaSetDevice(device);
+}
+
static int g_device_count = -1;
static int g_main_device = 0;
static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
-static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
-
static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
int64_t total_vram = 0;
fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
- for (int id = 0; id < g_device_count; ++id) {
+ for (int64_t id = 0; id < g_device_count; ++id) {
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
- fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
+        fprintf(stderr, "  Device %d: %s, compute capability %d.%d\n", (int) id, prop.name, prop.major, prop.minor);
g_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
}
- for (int id = 0; id < g_device_count; ++id) {
+ for (int64_t id = 0; id < g_device_count; ++id) {
g_tensor_split[id] /= total_vram;
}
- for (int id = 0; id < g_device_count; ++id) {
- CUDA_CHECK(cudaSetDevice(id));
+ for (int64_t id = 0; id < g_device_count; ++id) {
+ CUDA_CHECK(ggml_cuda_set_device(id));
- // create main stream
- CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
+ // create cuda streams
+ for (int64_t is = 0; is < MAX_STREAMS; ++is) {
+ CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[id][is], cudaStreamNonBlocking));
+ }
// create cublas handle
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
}
inline void ggml_cuda_op_add(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
-
- GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
- GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- const int64_t ne00 = src0->ne[0];
- const int64_t i01_diff = i01_high - i01_low;
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
- // compute
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
+ add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
- add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
+ add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
} else {
GGML_ASSERT(false);
}
(void) src1;
(void) dst;
- (void) src0_ddq_i;
- (void) i02;
- (void) i1;
}
inline void ggml_cuda_op_mul(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
-
- GGML_ASSERT(src0_ddf_i != nullptr);
- GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- const int64_t ne00 = src0->ne[0];
- const int64_t i01_diff = i01_high - i01_low;
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
- mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
+ mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
(void) dst;
- (void) src0_ddq_i;
- (void) i02;
- (void) i1;
}
inline void ggml_cuda_op_gelu(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- GGML_ASSERT(src0_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t i01_diff = i01_high - i01_low;
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
- // compute
- gelu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+ gelu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
- (void) src0_ddq_i;
- (void) src1_ddf_i;
- (void) i02;
- (void) i1;
+ (void) src1_dd;
}
inline void ggml_cuda_op_silu(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- GGML_ASSERT(src0_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t i01_diff = i01_high - i01_low;
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
- // compute
- silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+ silu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
- (void) src0_ddq_i;
- (void) src1_ddf_i;
- (void) i02;
- (void) i1;
+ (void) src1_dd;
}
inline void ggml_cuda_op_norm(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- GGML_ASSERT(src0_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
- const int64_t i01_diff = i01_high - i01_low;
+ const int64_t nrows = ggml_nrows(src0);
- // compute
- norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+ norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
(void) src1;
(void) dst;
- (void) src0_ddq_i;
- (void) src1_ddf_i;
- (void) i02;
- (void) i1;
+ (void) src1_dd;
}
inline void ggml_cuda_op_rms_norm(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- GGML_ASSERT(src0_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
- const int64_t i01_diff = i01_high - i01_low;
+ const int64_t nrows = ggml_nrows(src0);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
- // compute
- rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
+ rms_norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream);
(void) src1;
(void) dst;
- (void) src0_ddq_i;
- (void) src1_ddf_i;
- (void) i02;
- (void) i1;
+ (void) src1_dd;
}
inline void ggml_cuda_op_mul_mat_q(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
-
- GGML_ASSERT(src0_ddq_i != nullptr);
- GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, const cudaStream_t & stream) {
const int64_t ne00 = src0->ne[0];
const int64_t ne10 = src1->ne[0];
- const int64_t ne11 = src1->ne[1];
GGML_ASSERT(ne10 % QK8_1 == 0);
const int64_t ne0 = dst->ne[0];
- const int64_t i01_diff = i01_high - i01_low;
+ const int64_t row_diff = row_high - row_low;
int id;
CUDA_CHECK(cudaGetDevice(&id));
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
- const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
-
- const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ?
- ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
- size_t as;
- void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*sizeof(block_q8_1)/QK8_1, &as);
- quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, cudaStream_main);
+ const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
switch (src0->type) {
case GGML_TYPE_Q4_0:
- ggml_mul_mat_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ ggml_mul_mat_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
break;
case GGML_TYPE_Q4_1:
- ggml_mul_mat_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ ggml_mul_mat_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
break;
case GGML_TYPE_Q5_0:
- ggml_mul_mat_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ ggml_mul_mat_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
break;
case GGML_TYPE_Q5_1:
- ggml_mul_mat_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ ggml_mul_mat_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
break;
case GGML_TYPE_Q8_0:
- ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ ggml_mul_mat_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
break;
case GGML_TYPE_Q2_K:
- ggml_mul_mat_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ ggml_mul_mat_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
break;
case GGML_TYPE_Q3_K:
- ggml_mul_mat_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ ggml_mul_mat_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
break;
case GGML_TYPE_Q4_K:
- ggml_mul_mat_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ ggml_mul_mat_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
break;
case GGML_TYPE_Q5_K:
- ggml_mul_mat_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ ggml_mul_mat_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
break;
case GGML_TYPE_Q6_K:
- ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ ggml_mul_mat_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
break;
default:
GGML_ASSERT(false);
break;
}
- ggml_cuda_pool_free(src1_q8_1, as);
-
(void) src1;
(void) dst;
- (void) src0_ddf_i;
- (void) i02;
- (void) i1;
+ (void) src1_ddf_i;
}
static int64_t get_row_rounding(ggml_type type) {
}
}
-inline void ggml_cuda_op_mul_mat_vec(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
-
- GGML_ASSERT(src0_ddq_i != nullptr);
- GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+inline void ggml_cuda_op_mul_mat_vec_q(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, const cudaStream_t & stream) {
const int64_t ne00 = src0->ne[0];
- const int64_t nrows = i01_high - i01_low;
+ const int64_t row_diff = row_high - row_low;
-#ifdef GGML_CUDA_FORCE_DMMV
- const bool use_mul_mat_vec_q = false;
- (void) g_compute_capabilities[0];
-#else
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q2_K:
+ mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q3_K:
+ mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_K:
+ mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q6_K:
+ mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
- bool mul_mat_vec_q_implemented =
- src0->type == GGML_TYPE_Q4_0 ||
- src0->type == GGML_TYPE_Q4_1 ||
- src0->type == GGML_TYPE_Q5_0 ||
- src0->type == GGML_TYPE_Q5_1 ||
- src0->type == GGML_TYPE_Q8_0;
-#if QK_K == 256
- mul_mat_vec_q_implemented = mul_mat_vec_q_implemented ||
- src0->type == GGML_TYPE_Q2_K ||
- src0->type == GGML_TYPE_Q3_K ||
- src0->type == GGML_TYPE_Q4_K ||
- src0->type == GGML_TYPE_Q5_K ||
- src0->type == GGML_TYPE_Q6_K;
-#endif // QK_K == 256
-
- const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= MIN_CC_DP4A && mul_mat_vec_q_implemented;
-#endif
+ (void) src1;
+ (void) dst;
+ (void) src1_ddf_i;
+ (void) src1_ncols;
+ (void) src1_padded_row_size;
+}
- if (use_mul_mat_vec_q) {
- const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ?
- ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
- size_t as;
- void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
- quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, 1, padded_row_size, cudaStream_main);
-
- switch (src0->type) {
- case GGML_TYPE_Q4_0:
- mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q4_1:
- mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q5_0:
- mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q5_1:
- mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q8_0:
- mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q2_K:
- mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q3_K:
- mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q4_K:
- mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q5_K:
- mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q6_K:
- mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- default:
- GGML_ASSERT(false);
- break;
- }
+inline void ggml_cuda_op_dequantize_mul_mat_vec(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, const cudaStream_t & stream) {
- ggml_cuda_pool_free(src1_q8_1, as);
- } else {
- // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+ const int64_t ne00 = src0->ne[0];
+ const int64_t row_diff = row_high - row_low;
+
+ // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_CUDA_F16
- size_t ash;
- dfloat * src1_dfloat = nullptr; // dfloat == half
-
- bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
- src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
- src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
-
- if (src1_convert_f16) {
- src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
- ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
- ne00, 1, sizeof(float), 0, 0,
- ne00, 1, sizeof(half), 0, 0, cudaStream_main);
- }
+ size_t ash;
+ dfloat * src1_dfloat = nullptr; // dfloat == half
+
+ bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+ src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+ src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+ if (src1_convert_f16) {
+ src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+ ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
+ ne00, 1, sizeof(float), 0, 0,
+ ne00, 1, sizeof(half), 0, 0, stream);
+ }
#else
- dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
+ const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
#endif // GGML_CUDA_F16
- switch (src0->type) {
- case GGML_TYPE_Q4_0:
- dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q4_1:
- dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q5_0:
- dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q5_1:
- dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q8_0:
- dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q2_K:
- dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q3_K:
- dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q4_K:
- dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q5_K:
- dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q6_K:
- dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_F16:
- convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- default:
- GGML_ASSERT(false);
- break;
- }
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q2_K:
+ dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q3_K:
+ dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_K:
+ dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q6_K:
+ dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_F16:
+ convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
#ifdef GGML_CUDA_F16
- if (src1_convert_f16) {
- ggml_cuda_pool_free(src1_dfloat, ash);
- }
-#endif // GGML_CUDA_F16
+ if (src1_convert_f16) {
+ ggml_cuda_pool_free(src1_dfloat, ash);
}
+#endif // GGML_CUDA_F16
(void) src1;
(void) dst;
- (void) src0_ddf_i;
- (void) i02;
- (void) i1;
+ (void) src1_ddq_i;
+ (void) src1_ncols;
+ (void) src1_padded_row_size;
}
inline void ggml_cuda_op_mul_mat_cublas(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, const cudaStream_t & stream) {
- GGML_ASSERT(src0_ddf_i != nullptr);
+ GGML_ASSERT(src0_dd_i != nullptr);
GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(dst_dd_i != nullptr);
const float alpha = 1.0f;
const float beta = 0.0f;
const int64_t ne00 = src0->ne[0];
const int64_t ne10 = src1->ne[0];
- const int64_t ne11 = src1->ne[1];
const int64_t ne0 = dst->ne[0];
- const int64_t i01_diff = i01_high - i01_low;
+ const int64_t row_diff = row_high - row_low;
+
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+ size_t src0_as;
+ float * src0_ddf_i = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as);
+ to_fp32_cuda(src0_dd_i, src0_ddf_i, row_diff*ne00, stream);
int id;
CUDA_CHECK(cudaGetDevice(&id));
// the main device has a larger memory buffer to hold the results from all GPUs
// ldc == nrows of the matrix that cuBLAS writes into
- int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
+ int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
- CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], cudaStream_main));
+ CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
CUBLAS_CHECK(
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
- i01_diff, ne11, ne10,
+ row_diff, src1_ncols, ne10,
&alpha, src0_ddf_i, ne00,
- src1_ddf_i, ne10,
- &beta, dst_ddf_i, ldc));
+ src1_ddf_i, ne10,
+ &beta, dst_dd_i, ldc));
+
+ ggml_cuda_pool_free(src0_ddf_i, src0_as);
(void) dst;
- (void) src0_ddq_i;
- (void) i02;
- (void) i1;
+ (void) src1_ddq_i;
+ (void) src1_padded_row_size;
}
inline void ggml_cuda_op_rope(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- GGML_ASSERT(src0_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
- const int64_t i01_diff = i01_high - i01_low;
+ const int64_t nrows = ggml_nrows(src0);
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
// compute
if (is_glm) {
- rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, n_ctx, cudaStream_main);
+ rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream);
} else if (is_neox) {
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
- rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
+ rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
} else {
- rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
+ rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
}
(void) src1;
(void) dst;
- (void) src0_ddq_i;
- (void) src1_ddf_i;
- (void) i1;
+ (void) src1_dd;
}
inline void ggml_cuda_op_alibi(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- GGML_ASSERT(src0_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
- const int64_t i01_diff = i01_high - i01_low;
+ const int64_t nrows = ggml_nrows(src0);
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
- // compute
- alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
+ alibi_f32_cuda(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream);
(void) src1;
- (void) src0_ddq_i;
- (void) src1_ddf_i;
- (void) i1;
+ (void) src1_dd;
}
inline void ggml_cuda_op_diag_mask_inf(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- GGML_ASSERT(src0_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
- const int64_t i01_diff = i01_high - i01_low;
+ const int nrows0 = ggml_nrows(src0);
const int n_past = ((int32_t *) dst->op_params)[0];
- // compute
- diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
+ diag_mask_inf_f32_cuda(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
(void) src1;
(void) dst;
- (void) src0_ddq_i;
- (void) src1_ddf_i;
- (void) i02;
- (void) i1;
+ (void) src1_dd;
}
inline void ggml_cuda_op_soft_max(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- GGML_ASSERT(src0_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
- const int64_t i01_diff = i01_high - i01_low;
+ const int64_t nrows = ggml_nrows(src0);
- // compute
- soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+ soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
(void) src1;
(void) dst;
- (void) src0_ddq_i;
- (void) src1_ddf_i;
- (void) i02;
- (void) i1;
+ (void) src1_dd;
}
inline void ggml_cuda_op_scale(
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
- float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
- cudaStream_t & cudaStream_main){
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
- GGML_ASSERT(src0_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
const float scale = ((float *) src1->data)[0];
- const int64_t ne00 = src0->ne[0];
- const int64_t i01_diff = i01_high - i01_low;
-
- // compute
- scale_f32_cuda(src0_ddf_i, dst_ddf_i, scale, ne00*i01_diff, cudaStream_main);
+ scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
CUDA_CHECK(cudaGetLastError());
(void) src1;
(void) dst;
- (void) src0_ddq_i;
- (void) src1_ddf_i;
- (void) i02;
- (void) i1;
+ (void) src1_dd;
+}
+
+static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
+ const int64_t nrows0 = ggml_nrows(src0);
+
+ const bool use_src1 = src1 != nullptr;
+ const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+ GGML_ASSERT( src0->backend != GGML_BACKEND_GPU_SPLIT);
+ GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
+ GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT);
+
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+ struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+
+ const bool src0_on_device = src0->backend == GGML_BACKEND_GPU;
+ const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
+ const bool dst_on_device = dst->backend == GGML_BACKEND_GPU;
+
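+    // for GGML_OP_SCALE the scale factor is read directly from src1 host data (see ggml_cuda_op_scale),
+    // so src1 does not need to be copied to the device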
+ const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE;
+
+ // dd = data device
+ float * src0_ddf = nullptr;
+ float * src1_ddf = nullptr;
+ float * dst_ddf = nullptr;
+
+ // as = actual size
+ size_t src0_asf = 0;
+ size_t src1_asf = 0;
+ size_t dst_asf = 0;
+
+ ggml_cuda_set_device(g_main_device);
+ const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+ if (src0_on_device) {
+ src0_ddf = (float *) src0_extra->data_device[g_main_device];
+ } else {
+ src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf);
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
+ }
+
+ if (use_src1 && !src1_stays_on_host) {
+ if (src1_on_device) {
+ src1_ddf = (float *) src1_extra->data_device[g_main_device];
+ } else {
+ src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf);
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
+ }
+ }
+ if (dst_on_device) {
+ dst_ddf = (float *) dst_extra->data_device[g_main_device];
+ } else {
+ dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf);
+ }
+
+ // do the computation
+ op(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ // copy dst to host if necessary
+ if (!dst_on_device) {
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
+ }
+
+ if (src0_asf > 0) {
+ ggml_cuda_pool_free(src0_ddf, src0_asf);
+ }
+ if (src1_asf > 0) {
+ ggml_cuda_pool_free(src1_ddf, src1_asf);
+ }
+ if (dst_asf > 0) {
+ ggml_cuda_pool_free(dst_ddf, dst_asf);
+ }
+
+ if (dst->backend == GGML_BACKEND_CPU) {
+ CUDA_CHECK(cudaDeviceSynchronize());
+ }
}
-static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
- ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) {
+static void ggml_cuda_op_mul_mat(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
+ const bool convert_src1_to_q8_1) {
+
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t nrows0 = ggml_nrows(src0);
- const bool use_src1 = src1 != nullptr;
- const int64_t ne10 = use_src1 ? src1->ne[0] : 1;
- const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
- const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
- const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
- const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ const int64_t ne12 = src1->ne[2];
+ const int64_t ne13 = src1->ne[3];
+ const int64_t nrows1 = ggml_nrows(src1);
GGML_ASSERT(ne03 == ne13);
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
- const int nb2 = dst->nb[2];
- const int nb3 = dst->nb[3];
+ const int nb2 = dst->nb[2];
+ const int nb3 = dst->nb[3];
GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
- GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
+ GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
- // strides for iteration over dims 3 and 2
- const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13;
- const int64_t num_iters = flatten_rows ? 1 : num_iters_0;
- const int64_t stride_mod = flatten_rows ? num_iters_0 : 1;
- const int64_t src0_stride = ne00 * ne01 * stride_mod;
- const int64_t src1_stride = ne10 * ne11 * stride_mod;
- const int64_t dst_stride = ne0 * ne1 * stride_mod;
+ GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
- const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
- const int64_t i03_max = flatten_rows ? 1 : ne03;
- const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12);
- const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02;
- GGML_ASSERT(!(flatten_rows && ne02 < ne12));
+ const int64_t i02_divisor = ne12 / ne02;
const size_t src0_ts = ggml_type_size(src0->type);
const size_t src0_bs = ggml_blck_size(src0->type);
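+    // type size and block size of q8_1, the format that src1 is quantized to for the quantized mul_mat kernels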
+ const size_t q8_1_ts = sizeof(block_q8_1);
+ const size_t q8_1_bs = QK8_1;
- struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
- struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
- struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+ struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool src0_is_contiguous = ggml_is_contiguous(src0);
- const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
- const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1);
- const bool src1_stays_on_host = use_src1 && (
- dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
+ const bool src1_is_contiguous = ggml_is_contiguous(src1);
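+    // pad the src1 row size to a multiple of MATRIX_ROW_PADDING so the quantization and
+    // quantized mul_mat kernels do not read out of bounds at the end of a row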
+ const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ?
+ ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
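+    // row-split src0 tensors must be 2D: broadcasting over dims 2 and 3 is not supported for split tensors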
+ GGML_ASSERT(!(split && ne02 > 1));
+ GGML_ASSERT(!(split && ne03 > 1));
GGML_ASSERT(!(split && ne02 < ne12));
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
-
// dd = data device
- char * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
- float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
- float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
- float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
-
- // asq = actual size quantized, asf = actual size float
- size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
- size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
- size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
- size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
+ char * src0_dd[GGML_CUDA_MAX_DEVICES] = {nullptr};
+ float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
+ char * src1_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // q8_1
+ float * dst_dd[GGML_CUDA_MAX_DEVICES] = {nullptr};
- // if multiple devices are used they need to wait for the main device
- // here an event is recorded that signifies that the main device has finished calculating the input data
- if (split && g_device_count > 1) {
- CUDA_CHECK(cudaSetDevice(g_main_device));
- CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
- }
+ // as = actual size
+ size_t src0_as[GGML_CUDA_MAX_DEVICES] = {0};
+ size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
+ size_t src1_asq[GGML_CUDA_MAX_DEVICES] = {0};
+ size_t dst_as[GGML_CUDA_MAX_DEVICES] = {0};
- for (int id = 0; id < g_device_count; ++id) {
- if (!split && id != g_main_device) {
- continue;
- }
+ int64_t row_low[GGML_CUDA_MAX_DEVICES];
+ int64_t row_high[GGML_CUDA_MAX_DEVICES];
- const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU && id == g_main_device;
- const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
+ for (int64_t id = 0; id < g_device_count; ++id) {
+ // by default, use all rows
+ row_low[id] = 0;
+ row_high[id] = ne01;
- int64_t row_low, row_high;
+ // for multi GPU, get the row boundaries from tensor split
+ // and round to mul_mat_q tile sizes
if (split) {
const int64_t rounding = get_row_rounding(src0->type);
- row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
- row_low -= row_low % rounding;
+ if (id != 0) {
+ row_low[id] = ne01*g_tensor_split[id];
+ row_low[id] -= row_low[id] % rounding;
+ }
- if (id == g_device_count - 1) {
- row_high = nrows0;
- } else {
- row_high = nrows0*g_tensor_split[id + 1];
- row_high -= row_high % rounding;
+ if (id != g_device_count - 1) {
+ row_high[id] = ne01*g_tensor_split[id + 1];
+ row_high[id] -= row_high[id] % rounding;
}
- } else {
- row_low = 0;
- row_high = nrows0*i02_divisor;
}
- if (row_low == row_high) {
+ }
+
+ for (int64_t id = 0; id < g_device_count; ++id) {
+ if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
continue;
}
- int64_t row_diff = row_high - row_low;
+ const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
+ const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
- cudaSetDevice(id);
- cudaStream_t cudaStream_main = g_cudaStreams_main[id];
-
- // wait for main GPU data if necessary
- if (split && id != g_main_device) {
- CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
- }
+ ggml_cuda_set_device(id);
+ const cudaStream_t stream = g_cudaStreams[id][0];
if (src0_on_device && src0_is_contiguous) {
- if (src0_is_f32) {
- src0_ddf[id] = (float *) src0_extra->data_device[id];
- } else {
- src0_ddq[id] = (char *) src0_extra->data_device[id];
- }
+ src0_dd[id] = (char *) src0_extra->data_device[id];
} else {
- if (src0_is_f32) {
- src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
- } else {
- src0_ddq[id] = (char *) ggml_cuda_pool_malloc(row_diff*ne00 * src0_ts/src0_bs, &src0_asq[id]);
- }
+            const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc(size_src0_ddq, &src0_as[id]);
}
- if (src0_needs_f32 && !src0_is_f32) {
- src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
+ if (src1_on_device && src1_is_contiguous) {
+ src1_ddf[id] = (float *) src1_extra->data_device[id];
+ } else {
+ src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
}
- if (use_src1 && !src1_stays_on_host) {
- if (src1_on_device && src1_is_contiguous) {
- src1_ddf[id] = (float *) src1_extra->data_device[id];
- } else {
- src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]);
+ if (convert_src1_to_q8_1) {
+ src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
+
+ if (split && src1_on_device && src1_is_contiguous) {
+ quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
+ CUDA_CHECK(cudaGetLastError());
}
}
+
if (dst_on_device) {
- dst_ddf[id] = (float *) dst_extra->data_device[id];
+ dst_dd[id] = (float *) dst_extra->data_device[id];
} else {
- size_t size_dst_ddf = split ? row_diff*ne1 * sizeof(float) : num_iters*dst_stride * sizeof(float);
- dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
+ const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
+ dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
}
+ }
- for (int64_t i03 = 0; i03 < i03_max; i03++) {
- const int64_t i13 = i03 % ne13;
- for (int64_t i02 = 0; i02 < i02_max; i02++) {
- const int64_t i12 = i02 % ne12;
+ // if multiple devices are used they need to wait for the main device
+ // here an event is recorded that signals that the main device has finished calculating the input data
+ if (split && g_device_count > 1) {
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+ CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
+ }
- const int64_t i0 = i03*i02_max + i02;
+ const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
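+    // process src1 in chunks of src1_col_stride columns; with a row-split across multiple GPUs
+    // the chunks are distributed round-robin over the MAX_STREAMS streams of each device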
+ for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
+ const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
+ const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
- // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
- const int64_t i0_offset_low = row_low/rows_per_iter;
- const int64_t i0_offset_high = row_high/rows_per_iter;
+ for (int64_t id = 0; id < g_device_count; ++id) {
+ if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
+ continue;
+ }
- int64_t i01_low = 0;
- int64_t i01_high = rows_per_iter;
- if (split) {
- if (i0 < i0_offset_low || i0 > i0_offset_high) {
- continue;
- }
- if (i0 == i0_offset_low) {
- i01_low = row_low % rows_per_iter;
- }
- if (i0 == i0_offset_high) {
- i01_high = row_high % rows_per_iter;
- }
- }
+ const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
+ const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
+ const int64_t row_diff = row_high[id] - row_low[id];
- // There is possibly a bug in the Windows nvcc compiler regarding instruction reordering or optimizing out local variables.
- // Removing the first assert or changing the order of the arguments causes the second assert to fail.
- // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
- // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
- GGML_ASSERT(i01_low == 0 || g_device_count > 1);
- GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1);
+ ggml_cuda_set_device(id);
+ const cudaStream_t stream = g_cudaStreams[id][is];
- const int64_t i01_diff = i01_high - i01_low;
- if (i01_diff == 0) {
- continue;
- }
- const int64_t i11 = i13*ne12 + i12;
+ // wait for main GPU data if necessary
+ if (split && (id != g_main_device || is != 0)) {
+ CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0]));
+ }
+
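+            // iterate over the ne12*ne13 slices of dims 2 and 3; the matching src0 slice is selected via i02_divisor (broadcasting)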
+ for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
+ const int64_t i03 = i0 / ne12;
+ const int64_t i02 = i0 % ne12;
+
+ const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
-                // for split tensors the data begins at i0 == i0_offset_low
+                // for split tensors the device buffers only hold the rows [row_low, row_high)
- char * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs;
- float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride;
- float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
- float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
-
- // for split tensors the data pointer needs to be rounded down
- // to the bin edge for i03, i02 bins beyond the first
- if (i0 - i0_offset_low > 0) {
- GGML_ASSERT(!flatten_rows);
- src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
- src0_ddf_i -= (row_low % ne01)*ne00;
- dst_ddf_i -= (row_low % ne0)*ne1;
- }
+ char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs;
+ float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
+ char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset;
+ float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
// the main device memory buffer can be on VRAM scratch, with space for all partial results
// in that case an offset on dst_ddf_i is needed
if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
- dst_ddf_i += i01_low; // offset is 0 if no tensor split
+ dst_dd_i += row_low[id]; // offset is 0 if no tensor split
}
// copy src0, src1 to device if necessary
- if (use_src1 && !src1_stays_on_host) {
- if (src1->backend == GGML_BACKEND_CPU) {
- GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
- int64_t nrows1 = flatten_rows ? nrows0 : ne11;
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
- } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
- if (id != g_main_device) {
- GGML_ASSERT(!flatten_rows);
+ if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
+ if (id != g_main_device) {
+ if (convert_src1_to_q8_1) {
+ char * src1_ddq_i_source = src1_ddq[g_main_device] + src1_ddq_i_offset;
+ CUDA_CHECK(cudaMemcpyAsync(src1_ddq_i, src1_ddq_i_source, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs,
+ cudaMemcpyDeviceToDevice, stream));
+ } else {
float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
- src1_ddf_i_source += i11*src1_stride;
- CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
- cudaMemcpyDeviceToDevice, cudaStream_main));
+ src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+ CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_ncols*ne10*sizeof(float),
+ cudaMemcpyDeviceToDevice, stream));
}
- } else if (src1_on_device && !src1_is_contiguous) {
- GGML_ASSERT(!split);
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
- } else {
- GGML_ASSERT(false);
}
+ } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) {
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+ src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+ } else {
+ GGML_ASSERT(false);
}
- if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
- if (src0_is_f32) {
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
- } else {
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
- }
+ if (convert_src1_to_q8_1 && src1->backend == GGML_BACKEND_CPU) {
+ quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
+ CUDA_CHECK(cudaGetLastError());
}
- // convert src0 to f32 if it is necessary for the ggml_cuda_op
- if (src0_needs_f32 && !src0_is_f32) {
- to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
- CUDA_CHECK(cudaGetLastError());
+ if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream));
}
// do the computation
- op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+ op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+ row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
// copy dst to host or other device if necessary
// The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
// Instead they need to be copied to the correct slice in ne0 = dst row index.
// If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
- float * dhf_dst_i = (float *) ((char *) dst_off_device + i01_low*sizeof(float) + i02*nb2 + i03*nb3);
- CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_ddf_i, i01_diff*sizeof(float),
- i01_diff*sizeof(float), ne1, kind, cudaStream_main));
+ float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+ GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+ dhf_dst_i += src1_col_0*ne0 + row_low[id];
+ CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float),
+ row_diff*sizeof(float), src1_ncols, kind, stream));
} else {
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
- CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
+ GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+ dhf_dst_i += src1_col_0*ne0;
+ CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), kind, stream));
}
}
- // signify to main device that other device is done
- if (split && g_device_count > 1 && id != g_main_device) {
- CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
+            // add event for the main device to wait on until the other device is done
+ if (split && (id != g_main_device || is != 0)) {
+ CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));
}
}
}
}
- // wait until each device is finished, then free their buffers
- for (int id = 0; id < g_device_count; ++id) {
- if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
- continue;
- }
-
- CUDA_CHECK(cudaSetDevice(id));
+ for (int64_t id = 0; id < g_device_count; ++id) {
+ CUDA_CHECK(ggml_cuda_set_device(id));
- if (src0_asq[id] > 0) {
- ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
- }
- if (src0_asf[id] > 0) {
- ggml_cuda_pool_free(src0_ddf[id], src0_asf[id]);
+ // free buffers again when done
+ if (src0_as[id] > 0) {
+ ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
}
if (src1_asf[id] > 0) {
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
}
- if (dst_asf[id] > 0) {
- ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
+ if (src1_asq[id] > 0) {
+ ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
+ }
+ if (dst_as[id] > 0) {
+ ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
}
}
// main device waits for all other devices to be finished
if (split && g_device_count > 1) {
- CUDA_CHECK(cudaSetDevice(g_main_device));
- for (int id = 0; id < g_device_count; ++id) {
- if (id != g_main_device && src0_extra->events[id]) {
- CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
+ int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
+ is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS;
+
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+ for (int64_t id = 0; id < g_device_count; ++id) {
+ for (int64_t is = 0; is < is_max; ++is) {
+ CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is]));
}
}
}
if (dst->backend == GGML_BACKEND_CPU) {
- CUDA_CHECK(cudaSetDevice(g_main_device));
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
CUDA_CHECK(cudaDeviceSynchronize());
}
}
void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
- // Due to flatten_rows == true this does in practice not make a difference however.
- // Better solution would be nice but right now that would require disproportionate changes.
- GGML_ASSERT(
- (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
- src1->type == GGML_TYPE_F32 &&
- (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
}
void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
}
void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_gelu, true, true);
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
}
void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
}
void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
}
void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
}
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
const int64_t ne12 = src1->ne[2];
- CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+ cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
void * src0_ddq = src0_extra->data_device[g_main_device];
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
- ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main);
+ ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
}
void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];
- CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+ cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
void * src0_ddq = src0_extra->data_device[g_main_device];
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
- const int row_stride_x = nb01 / sizeof(half);
- const int channel_stride_x = nb02 / sizeof(half);
+ const int64_t row_stride_x = nb01 / sizeof(half);
+ const int64_t channel_stride_x = nb02 / sizeof(half);
- ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
+ ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
}
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
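+ // find the lowest compute capability among the devices that actually receive a slice of src0
+ // (i.e. those whose tensor_split range is non-empty)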
+ int64_t min_compute_capability = INT_MAX;
+ for (int64_t id = 0; id < g_device_count; ++id) {
+ if (min_compute_capability > g_compute_capabilities[id]
+ && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+ min_compute_capability = g_compute_capabilities[id];
+ }
+ }
+
if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
} else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
} else if (src0->type == GGML_TYPE_F32) {
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+ ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
- } else {
- int min_compute_capability = INT_MAX;
- for (int id = 0; id < g_device_count; ++id) {
- if (min_compute_capability > g_compute_capabilities[id]
- && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
- min_compute_capability = g_compute_capabilities[id];
- }
- }
- if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
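+ // prefer the quantized mul_mat_vec kernels (integer dot products) on devices with compute capability >= MIN_CC_DP4A,
+ // unless GGML_CUDA_FORCE_DMMV forces the dequantize-on-the-fly path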
+#ifdef GGML_CUDA_FORCE_DMMV
+ const bool use_mul_mat_vec_q = false;
+#else
+ const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+#endif // GGML_CUDA_FORCE_DMMV
+
+ if (use_mul_mat_vec_q) {
+ ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+ } else {
+ ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+ }
+ } else {
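+ // matrix-matrix path: use the custom quantized matmul kernels when mul_mat_q is enabled, src1 is already on the GPU,
+ // src0 is quantized, and DP4A is available; otherwise fall back to cuBLAS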
+ if (src1->backend == GGML_BACKEND_GPU && g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
+ ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
} else {
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+ ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
}
}
} else {
}
void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true, true);
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
}
void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t nb11 = src1->nb[1];
const int64_t nb12 = src1->nb[2];
- CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+ cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
- ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+ ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
- ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+ ne10, ne11, nb10, nb11, nb12, main_stream);
} else {
GGML_ASSERT(false);
}
}
void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf);
}
void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true, true);
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max);
}
void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
-
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, true);
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope);
}
void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
}
void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
}
void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
- int nrows = ggml_nrows(tensor);
+ const int64_t nrows = ggml_nrows(tensor);
const int64_t ne0 = tensor->ne[0];
struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
memset(extra, 0, sizeof(*extra));
- for (int id = 0; id < g_device_count; ++id) {
+ for (int64_t id = 0; id < g_device_count; ++id) {
if (backend == GGML_BACKEND_GPU && id != g_main_device) {
continue;
}
- cudaSetDevice(id);
+ ggml_cuda_set_device(id);
- int row_low, row_high;
+ int64_t row_low, row_high;
if (backend == GGML_BACKEND_GPU) {
row_low = 0;
row_high = nrows;
extra->data_device[id] = buf;
if (backend == GGML_BACKEND_GPU_SPLIT) {
- CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
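+ // create one event per stream so the main device can wait on each device/stream combination individually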
+ for (int64_t is = 0; is < MAX_STREAMS; ++is) {
+ CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
+ }
}
}
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
- for (int id = 0; id < g_device_count; ++id) {
+ for (int64_t id = 0; id < g_device_count; ++id) {
if (extra->data_device[id] != nullptr) {
- CUDA_CHECK(cudaSetDevice(id));
+ CUDA_CHECK(ggml_cuda_set_device(id));
CUDA_CHECK(cudaFree(extra->data_device[id]));
}
- if (extra->events[id] != nullptr) {
- CUDA_CHECK(cudaSetDevice(id));
- CUDA_CHECK(cudaEventDestroy(extra->events[id]));
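+ // only split tensors create events; entries that were never created remain nullptr and are skipped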
+ for (int64_t is = 0; is < MAX_STREAMS; ++is) {
+ if (extra->events[id][is] != nullptr) {
+ CUDA_CHECK(ggml_cuda_set_device(id));
+ CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
+ }
}
}
force_inplace;
const size_t size = ggml_nbytes(tensor);
- CUDA_CHECK(cudaSetDevice(g_main_device));
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];