]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fastdiv, launch bounds for mmvq + q8_1 quant (llama/15802)
authorJohannes Gäßler <redacted>
Fri, 5 Sep 2025 14:07:02 +0000 (16:07 +0200)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* CUDA: fastdiv, launch bounds for mmvq + q8_1 quant

src/ggml-cuda/common.cuh
src/ggml-cuda/mmvq.cu
src/ggml-cuda/quantize.cu

index a2dc26eab7e4c4e59f26a05336fa74fb3f1c8f27..931524a2008974889e47868913ff93dbc004dfaa 100644 (file)
@@ -570,6 +570,8 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 //
 // n/d = (mulhi(n, mp) + n) >> L;
 static const uint3 init_fastdiv_values(uint32_t d) {
+    GGML_ASSERT(d != 0);
+
     // compute L = ceil(log2(d));
     uint32_t L = 0;
     while (L < 32 && (uint32_t{ 1 } << L) < d) {
index b7c3079308e3f473e3e3fb3964f8157971ee1362..52de4e78d1321737dc58811afcdf279a123f40fa 100644 (file)
@@ -141,9 +141,10 @@ template <ggml_type type, int ncols_dst>
 __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
         const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
-        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
+        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
+        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
+        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
 
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
@@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q(
     constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
     // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
-    const int channel_dst = blockIdx.y;
-    const int channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]          : channel_dst / channel_ratio;
-    const int channel_y   = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
-    const int sample_dst  = blockIdx.z;
-    const int sample_x    = sample_dst / sample_ratio;
-    const int sample_y    = sample_dst;
+    const uint32_t channel_dst = blockIdx.y;
+    const uint32_t channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
+    const uint32_t channel_y   = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
+    const uint32_t sample_dst  = blockIdx.z;
+    const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);
+    const uint32_t sample_y    = sample_dst;
 
     // partial sum for each thread
     float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
@@ -247,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst(
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
 
-    const int channel_ratio = nchannels_dst / nchannels_x;
-    const int sample_ratio  = nsamples_dst  / nsamples_x;
+    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
+    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0)              : init_fastdiv_values(nchannels_dst / nchannels_x);
+    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x);
 
     const int device = ggml_cuda_get_device();
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
@@ -256,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst(
 
     GGML_ASSERT(!ids || ncols_dst == 1);
     switch (ncols_dst) {
-        case 1:
-        {
+        case 1: {
             constexpr int c_ncols_dst = 1;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-            break;
-        }
-        case 2:
-        {
+                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 2: {
             constexpr int c_ncols_dst = 2;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-            break;
-        }
-        case 3:
-        {
+                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 3: {
             constexpr int c_ncols_dst = 3;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-            break;
-        }
-        case 4:
-        {
+                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 4: {
             constexpr int c_ncols_dst = 4;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-            break;
-        }
-        case 5:
-        {
+                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 5: {
             constexpr int c_ncols_dst = 5;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-            break;
-        }
-        case 6:
-        {
+                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 6: {
             constexpr int c_ncols_dst = 6;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-            break;
-        }
-        case 7:
-        {
+                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 7: {
             constexpr int c_ncols_dst = 7;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-            break;
-        }
-        case 8:
-        {
+                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 8: {
             constexpr int c_ncols_dst = 8;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-            break;
-        }
+                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
         default:
             GGML_ABORT("fatal error");
             break;
index a0b03a740d74c9814a61f51f17096826b963a17d..5117f9ffc0ff9cd3bdcdbde39b0ee2d8d209503d 100644 (file)
@@ -1,26 +1,27 @@
 #include "quantize.cuh"
 #include <cstdint>
 
+__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
 static __global__ void quantize_q8_1(
         const float * __restrict__ x, void * __restrict__ vy,
         const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-        const int64_t ne0, const int ne1, const int ne2) {
+        const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
     const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i0 >= ne0) {
         return;
     }
 
+    const int64_t i3 = fastdiv(blockIdx.z, ne2);
+    const int64_t i2 = blockIdx.z - i3*ne2.z;
     const int64_t i1 = blockIdx.y;
-    const int64_t i2 = blockIdx.z % ne2;
-    const int64_t i3 = blockIdx.z / ne2;
 
     const int64_t & i00 = i0;
     const int64_t & i01 = i1;
     const int64_t & i02 = i2;
     const int64_t & i03 = i3;
 
-    const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
+    const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;
 
     block_q8_1 * y = (block_q8_1 *) vy;
 
@@ -31,10 +32,10 @@ static __global__ void quantize_q8_1(
     float amax = fabsf(xi);
     float sum = xi;
 
-    amax = warp_reduce_max(amax);
-    sum  = warp_reduce_sum(sum);
+    amax = warp_reduce_max<QK8_1>(amax);
+    sum  = warp_reduce_sum<QK8_1>(sum);
 
-    const float  d = amax / 127;
+    const float  d = amax / 127.0f;
     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
 
     y[ib].qs[iqs] = q;
@@ -43,8 +44,7 @@ static __global__ void quantize_q8_1(
         return;
     }
 
-    reinterpret_cast<half&>(y[ib].ds.x) = d;
-    reinterpret_cast<half&>(y[ib].ds.y) = sum;
+    y[ib].ds = make_half2(d, sum);
 }
 
 template <mmq_q8_1_ds_layout ds_layout>
@@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda(
     GGML_ASSERT(!ids);
     GGML_ASSERT(ne0 % QK8_1 == 0);
 
+    const uint3 ne2_fastdiv = init_fastdiv_values(ne2);
+
     const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
     const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
     GGML_UNUSED(type_src0);
 }