1;
}
+// Selects which tuning table calc_nwarps() / calc_rows_per_block() consult.
+enum mmvq_parameter_table_id {
+    MMVQ_PARAMETERS_GENERIC = 0, // fallback when neither of the AMD groups below matches
+    MMVQ_PARAMETERS_GCN     = 1, // chosen for GCN and CDNA
+    MMVQ_PARAMETERS_RDNA2   = 2, // chosen for RDNA2 and RDNA3
+};
+
+// Device-side, compile-time choice of the parameter table: the architecture
+// macros defined for the current compilation pass decide the result.
+//   RDNA2 / RDNA3 -> MMVQ_PARAMETERS_RDNA2
+//   GCN / CDNA    -> MMVQ_PARAMETERS_GCN
+//   anything else -> MMVQ_PARAMETERS_GENERIC
+static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
+    return
+#if defined(RDNA2) || defined(RDNA3)
+        MMVQ_PARAMETERS_RDNA2
+#elif defined(GCN) || defined(CDNA)
+        MMVQ_PARAMETERS_GCN
+#else
+        MMVQ_PARAMETERS_GENERIC
+#endif
+        ;
+}
+
+// Host-side counterpart of the compile-time lookup: maps a compute capability
+// value `cc` to a parameter table using the GGML_CUDA_CC_IS_* predicates.
+// The RDNA2/RDNA3 check is made first, then GCN/CDNA; everything else is generic.
+static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
+    return (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) ? MMVQ_PARAMETERS_RDNA2
+         : (GGML_CUDA_CC_IS_GCN(cc)   || GGML_CUDA_CC_IS_CDNA(cc))  ? MMVQ_PARAMETERS_GCN
+         :                                                            MMVQ_PARAMETERS_GENERIC;
+}
+
+// Number of warps per block for a given batch width (ncols_y) and parameter table.
+//
+//   table    | ncols_y 1..4 | ncols_y 5..8 | other ncols_y
+//   GENERIC  |      4       |      2       |      1
+//   GCN      |      2       |      1       |      1
+//   others   |      1       |      1       |      1
+static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
+    const bool cols_1_to_4 = ncols_y >= 1 && ncols_y <= 4;
+    const bool cols_5_to_8 = ncols_y >= 5 && ncols_y <= 8;
+
+    switch (table_id) {
+        case MMVQ_PARAMETERS_GENERIC:
+            return cols_1_to_4 ? 4 : (cols_5_to_8 ? 2 : 1);
+        case MMVQ_PARAMETERS_GCN:
+            return cols_1_to_4 ? 2 : 1;
+        default:
+            // every remaining table (e.g. MMVQ_PARAMETERS_RDNA2) uses a single warp
+            return 1;
+    }
+}
+
+// Number of rows handled by one block for a given batch width (ncols_y) and table.
+// The GENERIC and GCN tables give 2 rows for ncols_y in 2..8 and 1 row otherwise;
+// any other table always gives 1 row.
+// NOTE(review): table_id is a plain int here while calc_nwarps takes
+// mmvq_parameter_table_id - confirm whether the difference is intentional.
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
+    const bool table_has_two_row_entries = table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN;
+    const bool cols_2_to_8               = ncols_y >= 2 && ncols_y <= 8;
+
+    return (table_has_two_row_entries && cols_2_to_8) ? 2 : 1;
+}
+
template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
+ constexpr mmvq_parameter_table_id table_id = get_device_table_id();
+ constexpr int nwarps = calc_nwarps(ncols_y, table_id);
+ constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
- constexpr int nwarps = 1;
- constexpr int rows_per_cuda_block = 1;
-#else
- constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
- constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
-
- const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+ const int tid = warp_size*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_col_y = nrows_y / QK8_1;
- constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+ constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
-// partial sum for each thread
+ // partial sum for each thread
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
const block_q8_1 * y = (const block_q8_1 *) vy;
}
}
- __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+ __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
if (threadIdx.y > 0) {
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
for (int l = 0; l < nwarps-1; ++l) {
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
}
- tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+ tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
}
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
}
}
+// Computes the launch geometry for a given batch width and parameter table.
+//
+//   ncols_y   - batch width used to index the parameter tables
+//   nrows_x   - number of rows to distribute over the grid
+//   warp_size - x extent of each block
+//   table_id  - parameter table passed to calc_rows_per_block / calc_nwarps
+//
+// Returns {grid dims, block dims}:
+//   grid  = (ceil(nrows_x / rows_per_block), 1, 1)
+//   block = (warp_size, calc_nwarps(ncols_y, table_id), 1)
+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
+    // Look the table up once; the value is both the rounding term and the divisor.
+    const int64_t rows_per_block = calc_rows_per_block(ncols_y, table_id);
+    // Ceil-division done in 64 bits so nrows_x + rows_per_block - 1 cannot wrap in int.
+    const int64_t nblocks = (nrows_x + rows_per_block - 1) / rows_per_block;
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+    return {block_nums, block_dims};
+}
+
template <ggml_type type>
static void mul_mat_vec_q_cuda(
const void * vx, const void * vy, float * dst,
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
- int id = ggml_cuda_get_device();
-
- int64_t nwarps = 1;
- int64_t rows_per_cuda_block = 1;
-
- if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
- switch(ncols_y) {
- case 1:
- nwarps = 4;
- rows_per_cuda_block = 1;
- break;
- case 2:
- case 3:
- case 4:
- nwarps = 4;
- rows_per_cuda_block = 2;
- break;
- case 5:
- case 6:
- case 7:
- case 8:
- nwarps = 2;
- rows_per_cuda_block = 2;
- break;
- default:
- GGML_ABORT("fatal error");
- break;
- }
- }
-
- const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
- const dim3 block_nums(nblocks, 1, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const int device = ggml_cuda_get_device();
+ const int warp_size = ggml_cuda_info().devices[device].warp_size;
+ const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
switch (ncols_y) {
case 1:
- mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ {
+ constexpr int c_ncols_y = 1;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
+ }
case 2:
- mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ {
+ constexpr int c_ncols_y = 2;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
+ }
case 3:
- mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ {
+ constexpr int c_ncols_y = 3;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
+ }
case 4:
- mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ {
+ constexpr int c_ncols_y = 4;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
+ }
case 5:
- mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ {
+ constexpr int c_ncols_y = 5;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
+ }
case 6:
- mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ {
+ constexpr int c_ncols_y = 6;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
+ }
case 7:
- mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ {
+ constexpr int c_ncols_y = 7;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
+ }
case 8:
- mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ {
+ constexpr int c_ncols_y = 8;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
+ }
default:
GGML_ABORT("fatal error");
break;