]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: use fastdiv + ggml_cuda_mad for mmvf (#16557)
authorAman Gupta <redacted>
Tue, 14 Oct 2025 11:16:21 +0000 (19:16 +0800)
committerGitHub <redacted>
Tue, 14 Oct 2025 11:16:21 +0000 (13:16 +0200)
* CUDA: use fastdiv + ggml_cuda_mad for mmvf

* use bf16 directly + fix formatting

* Add exception for HIP code

ggml/src/ggml-cuda/mmvf.cu

index 5b21ef05b3c3595a867c0b2b47c06f3004157231..57ab839393aa00746779cc311d719ff812e2c5ac 100644 (file)
@@ -7,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
 static __global__ void mul_mat_vec_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
         const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
-        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
     const int row         = blockIdx.x;
     const int channel_dst = blockIdx.y;
-    const int channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int channel_x   = ids ? ids[channel_dst]          : fastdiv((uint32_t) channel_dst, channel_ratio);
     const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
     const int sample_dst  = blockIdx.z;
-    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio);
     const int sample_y    = sample_dst;
     const int tid         = threadIdx.x;
 
@@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f(
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += tmpx.x*tmpy.x;
-                sumf[j] += tmpx.y*tmpy.y;
+                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
             }
         }
     } else if constexpr (std::is_same_v<T, half>) {
@@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f(
 #pragma unroll
                 for (int j = 0; j < ncols_dst; ++j) {
                     const float2 tmpy = y2[j*stride_col_y2 + col2];
-                    sumf[j] += tmpx.x * tmpy.x;
-                    sumf[j] += tmpx.y * tmpy.y;
+                    ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                    ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
                 }
             }
         } else {
@@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f(
 #endif // FP16_AVAILABLE
         }
     } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+//TODO: add support for ggml_cuda_mad for hip_bfloat162
+#if defined(GGML_USE_HIP)
         const int * x2 = (const int *) x;
         for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const int tmpx = x2[col2];
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
-                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+                const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
+                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
+                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
             }
         }
+#else
+        const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const nv_bfloat162 tmpx = x2[col2];
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+            }
+        }
+#endif
     } else {
         static_assert(std::is_same_v<T, void>, "unsupported type");
     }
@@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda(
     GGML_ASSERT(stride_col_y % 2 == 0);
     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
     GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
-    const int64_t channel_ratio = nchannels_dst / nchannels_x;
-    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
+    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
+    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x);
 
     const int device = ggml_cuda_get_device();
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
@@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda(
         case   32: {
             mul_mat_vec_f<T, type_acc, ncols_dst,  32><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   64: {
             mul_mat_vec_f<T, type_acc, ncols_dst,  64><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   96: {
             mul_mat_vec_f<T, type_acc, ncols_dst,  96><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  128: {
             mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  160: {
             mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  192: {
             mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  224: {
             mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  256: {
             mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");