]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix --split-mode row for MMQ (#13323)
authorJohannes Gäßler <redacted>
Tue, 6 May 2025 06:36:46 +0000 (08:36 +0200)
committerGitHub <redacted>
Tue, 6 May 2025 06:36:46 +0000 (08:36 +0200)
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh

index 4ccda88630a85881eb6054ee933d88deb42fd545..347d27d552573ddb32e3fcb3d2744ded8476bda0 100644 (file)
@@ -128,7 +128,7 @@ void ggml_cuda_mul_mat_q(
 
         const mmq_args args = {
             src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
-            ne00, ne01, ne1, s01, s1,
+            ne00, ne01, ne1, s01, ne11, s1,
             ne02, ne12, s02, s12, s2,
             ne03, ne13, s03, s13, s3,
             use_stream_k};
@@ -212,7 +212,7 @@ void ggml_cuda_mul_mat_q(
     // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
     const mmq_args args = {
         src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
-        ne00, ne01, ne_get_rows, s01, s1,
+        ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
         ne02, ne02, s02, s12, s2,
         ne03, ne13, s03, s13, s3,
         use_stream_k};
@@ -251,7 +251,7 @@ void ggml_cuda_op_mul_mat_q(
         ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
     const mmq_args args = {
         src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
-        ne00, row_diff, src1_ncols, stride01, nrows_dst,
+        ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
         1, 1, 0, 0, 0,
         1, 1, 0, 0, 0,
         use_stream_k};
index e1096dce6d90eb44261034190d1d09a2c840b9ac..b8143a7b23b39a445bda378dec47d46e636c5dc8 100644 (file)
@@ -2522,7 +2522,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
 static __device__ __forceinline__ void mul_mat_q_process_tile(
         const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
         const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
-        const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
+        const int nrows_x, const int stride_row_x, const int ncols_y, const int stride_col_dst,
         const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
 
     constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
@@ -2606,7 +2606,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 static __global__ void mul_mat_q(
         const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
         const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
-        const int ncols_x, const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
+        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
         const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 
@@ -2619,8 +2619,8 @@ static __global__ void mul_mat_q(
     constexpr int qk    = ggml_cuda_type_traits<type>::qk;
     constexpr int mmq_y = get_mmq_y_device();
 
-    const int ntx = (ncols_y + mmq_x - 1) / mmq_x; // Number of tiles x
-    const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
+    const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (nrows_x   + mmq_y - 1) / mmq_y; // Number of tiles y
 
     // Initialize the ids for writing back data with just the index.
     // For regular matrix multiplications this is never changed.
@@ -2648,8 +2648,8 @@ static __global__ void mul_mat_q(
 
         // Defaults for regular matrix multiplication:
         int col_low    = 0;
-        int col_high   = ncols_y;
-        int col_diff   = ncols_y;
+        int col_high   = ncols_dst;
+        int col_diff   = ncols_dst;
         int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
         int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 
@@ -2689,7 +2689,7 @@ static __global__ void mul_mat_q(
 
         constexpr bool fixup = false;
         mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
              tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
         return;
     }
@@ -2720,8 +2720,8 @@ static __global__ void mul_mat_q(
 
         // Defaults for regular matrix multiplication:
         int col_low    = 0;
-        int col_high   = ncols_y;
-        int col_diff   = ncols_y;
+        int col_high   = ncols_dst;
+        int col_diff   = ncols_dst;
         int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
         int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 
@@ -2767,7 +2767,7 @@ static __global__ void mul_mat_q(
 
         constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
              tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 
         kbc += blocks_per_ne00;
@@ -2792,8 +2792,8 @@ static __global__ void mul_mat_q(
 
     // Defaults for regular matrix multiplication:
     int col_low    = 0;
-    int col_high   = ncols_y;
-    int col_diff   = ncols_y;
+    int col_high   = ncols_dst;
+    int col_diff   = ncols_dst;
     int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
     int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 
@@ -2834,7 +2834,7 @@ static __global__ void mul_mat_q(
 
     constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-        (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+        (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
          tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 }
 
@@ -2842,7 +2842,7 @@ static __global__ void mul_mat_q(
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 static __global__ void mul_mat_q_stream_k_fixup(
         const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
-        const int ncols_x, const int nrows_x, const int ncols_y, const int stride_col_dst,
+        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
         const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
     constexpr int     mmq_y           = get_mmq_y_device();
     constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
@@ -2851,8 +2851,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
 
     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
-    const int ntx  = (ncols_y + mmq_x - 1) / mmq_x;
-    const int nty  = (nrows_x + mmq_y - 1) / mmq_y;
+    const int ntx  = (ncols_dst + mmq_x - 1) / mmq_x;
+    const int nty  = (nrows_x   + mmq_y - 1) / mmq_y;
 
     const int bidx0 = blockIdx.x;
 
@@ -2925,8 +2925,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
         const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
         dst += offset_dst;
 
-        const int i_max = nrows_x - it*mmq_y - 1;
-        const int j_max = ncols_y - jt*mmq_x - 1;
+        const int i_max = nrows_x   - it*mmq_y - 1;
+        const int j_max = ncols_dst - jt*mmq_x - 1;
 
 #pragma unroll
         for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -2989,7 +2989,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
 
 struct mmq_args {
     const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
-    int64_t ncols_x; int64_t nrows_x; int64_t ncols_y; int64_t stride_row_x; int64_t nrows_dst;
+    int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
     int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
     int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
     bool use_stream_k;
@@ -3025,8 +3025,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     }
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
 
-    const int nty  = (args.nrows_x + mmq_y - 1) / mmq_y;
-    const int ntx  = (args.ncols_y + mmq_x - 1) / mmq_x;
+    const int nty  = (args.nrows_x   + mmq_y - 1) / mmq_y;
+    const int ntx  = (args.ncols_dst + mmq_x - 1) / mmq_x;
     const int ntzw = args.nchannels_y * args.nsamples_y;
     const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
 
@@ -3040,14 +3040,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
             constexpr bool need_check = false;
             mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
-                 args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
                  sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
         } else {
             constexpr bool need_check = true;
             mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
-                 args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
                  sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
         }
@@ -3068,7 +3068,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
         mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
-             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
              channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
              sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
 
@@ -3077,14 +3077,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
         }
 
         mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
-            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
+            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
              args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
     } else {
         constexpr bool need_check = true;
 
         mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
-             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
              channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
              sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
 
@@ -3093,7 +3093,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
         }
 
         mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
-            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
+            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
              args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
     }
 }