From: Johannes Gäßler Date: Fri, 5 Sep 2025 14:07:02 +0000 (+0200) Subject: CUDA: fastdiv, launch bounds for mmvq + q8_1 quant (llama/15802) X-Git-Tag: v0.9.1~79 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=72febac2c8059eb5f90b58c811c7b38d35b2c153;p=pkg%2Fggml%2Fsources%2Fggml CUDA: fastdiv, launch bounds for mmvq + q8_1 quant (llama/15802) * CUDA: fastdiv, launch bounds for mmvq + q8_1 quant --- diff --git a/src/ggml-cuda/common.cuh b/src/ggml-cuda/common.cuh index a2dc26ea..931524a2 100644 --- a/src/ggml-cuda/common.cuh +++ b/src/ggml-cuda/common.cuh @@ -570,6 +570,8 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { // // n/d = (mulhi(n, mp) + n) >> L; static const uint3 init_fastdiv_values(uint32_t d) { + GGML_ASSERT(d != 0); + // compute L = ceil(log2(d)); uint32_t L = 0; while (L < 32 && (uint32_t{ 1 } << L) < d) { diff --git a/src/ggml-cuda/mmvq.cu b/src/ggml-cuda/mmvq.cu index b7c30793..52de4e78 100644 --- a/src/ggml-cuda/mmvq.cu +++ b/src/ggml-cuda/mmvq.cu @@ -141,9 +141,10 @@ template __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; @@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q( constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. - const int channel_dst = blockIdx.y; - const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio; - const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; - const int sample_y = sample_dst; + const uint32_t channel_dst = blockIdx.y; + const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + const uint32_t sample_dst = blockIdx.z; + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); + const uint32_t sample_y = sample_dst; // partial sum for each thread float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; @@ -247,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst( GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); - const int channel_ratio = nchannels_dst / nchannels_x; - const int sample_ratio = nsamples_dst / nsamples_x; + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); + const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); const int warp_size = ggml_cuda_info().devices[device].warp_size; @@ -256,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst( GGML_ASSERT(!ids || ncols_dst == 1); switch (ncols_dst) { - case 1: - { + case 1: { constexpr int c_ncols_dst = 1; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 2: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 2: { constexpr int c_ncols_dst = 2; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 3: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 3: { constexpr int c_ncols_dst = 3; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 4: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 4: { constexpr int c_ncols_dst = 4; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 5: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 5: { constexpr int c_ncols_dst = 5; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 6: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 6: { constexpr int c_ncols_dst = 6; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 7: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 7: { constexpr int c_ncols_dst = 7; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 8: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 8: { constexpr int c_ncols_dst = 8; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; default: GGML_ABORT("fatal error"); break; diff --git a/src/ggml-cuda/quantize.cu b/src/ggml-cuda/quantize.cu index a0b03a74..5117f9ff 100644 --- a/src/ggml-cuda/quantize.cu +++ b/src/ggml-cuda/quantize.cu @@ -1,26 +1,27 @@ #include "quantize.cuh" #include +__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1) static __global__ void quantize_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t ne0, const int ne1, const int ne2) { + const int64_t ne0, const uint32_t ne1, const uint3 ne2) { const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { return; } + const int64_t i3 = fastdiv(blockIdx.z, ne2); + const int64_t i2 = blockIdx.z - i3*ne2.z; const int64_t i1 = blockIdx.y; - const int64_t i2 = blockIdx.z % ne2; - const int64_t i3 = blockIdx.z / ne2; const int64_t & i00 = i0; const int64_t & i01 = i1; const int64_t & i02 = i2; const int64_t & i03 = i3; - const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0; + const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0; block_q8_1 * y = (block_q8_1 *) vy; @@ -31,10 +32,10 @@ static __global__ void quantize_q8_1( float amax = fabsf(xi); float sum = xi; - amax = warp_reduce_max(amax); - sum = warp_reduce_sum(sum); + amax = warp_reduce_max(amax); + sum = warp_reduce_sum(sum); - const float d = amax / 127; + const float d = amax / 127.0f; const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); y[ib].qs[iqs] = q; @@ -43,8 +44,7 @@ static __global__ void quantize_q8_1( return; } - reinterpret_cast(y[ib].ds.x) = d; - reinterpret_cast(y[ib].ds.y) = sum; + y[ib].ds = make_half2(d, sum); } template @@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda( GGML_ASSERT(!ids); GGML_ASSERT(ne0 % QK8_1 == 0); + const uint3 ne2_fastdiv = init_fastdiv_values(ne2); + const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); GGML_UNUSED(type_src0); }