]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: mul_mat_id for mmf for bs <= 64 for f16 and bs <= 32 for f32 (llama/16277)
authorAman Gupta <redacted>
Sat, 27 Sep 2025 16:49:32 +0000 (00:49 +0800)
committerGeorgi Gerganov <redacted>
Mon, 29 Sep 2025 12:18:11 +0000 (15:18 +0300)
* CUDA: mul_mat_id for mmf for bs <= 64 for f16 and bs <= 32 for f32

This commit adds mul_mat_id support for ncols_dst >= 16. It does this by
packing ncols_dst tiles into the blockDim.y.

My tests on a RTX 3090 show that this is faster than the cuBLAS fallback
for f16 till bs=64, and for f32 till bs=32

* Review: refactor if statement

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmf.cu
ggml/src/ggml-cuda/mmf.cuh

index 8c8647b147369d27f84154e2d505c320cbfa39f5..5cd1e0d862db4fe96eeb4228555d6ffe9451edad 100644 (file)
@@ -2031,7 +2031,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             const int cc            = ggml_cuda_info().devices[id].cc;
             const int warp_size     = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+            use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
             use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
         }
@@ -2039,7 +2039,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         const int cc            = ggml_cuda_info().devices[ctx.device].cc;
         const int warp_size     = ggml_cuda_info().devices[ctx.device].warp_size;
         use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+        use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
         use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
     }
@@ -2111,7 +2111,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             return;
         }
 
-        if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
+        if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
             ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
             return;
         }
index 16331e9ecfad0d9d89fe8ce0eb1057cdf5d4fdd3..599e085ee91b7e058ea7184bab002064f38b5a4b 100644 (file)
@@ -84,7 +84,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
     }
 }
 
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id) {
 
     if (ggml_is_quantized(type)) {
         return false;
@@ -96,8 +96,18 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
     if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
         return false;
     }
-    if (src1_ncols > 16) {
-        return false;
+
+    if (mul_mat_id) {
+        if (type == GGML_TYPE_F32 && src1_ncols > 32) {
+            return false;
+        }
+        if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
+            return false;
+        }
+    } else {
+        if (src1_ncols > 16) {
+            return false;
+        }
     }
 
     switch (type) {
index 61e3bf30152c75a8618c9da318df04cbf42a3350..a6c3adfcf1704a99538805898f148dff3d11eb27 100644 (file)
@@ -9,13 +9,13 @@ using namespace ggml_cuda_mma;
 
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
 
 template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
 __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
 static __global__ void mul_mat_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
+        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
         const int stride_col_id, const int stride_row_id,
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
@@ -31,9 +31,20 @@ static __global__ void mul_mat_f(
 
     const int row0        = blockIdx.x * rows_per_block;
 
-    const int expert_idx  = has_ids ? blockIdx.y : 0;
+    int expert_idx = 0;
+    int col_base = 0;
+
     const int channel_dst = has_ids ? 0 : blockIdx.y;
 
+    if constexpr (has_ids) {
+        // experts + tiles of ncols_dst are packed in the y dimension
+        int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
+        const int nchannels_x = gridDim.y / col_tiles;
+        const int tile_idx = blockIdx.y / nchannels_x;
+        expert_idx = blockIdx.y - tile_idx * nchannels_x;
+        col_base = tile_idx * cols_per_block;
+    }
+
     const int channel_x   = has_ids ? expert_idx : (channel_dst / channel_ratio);
     const int channel_y   = channel_dst;
     const int sample_dst  = blockIdx.z;
@@ -44,6 +55,14 @@ static __global__ void mul_mat_f(
     y   += int64_t(sample_y)  *stride_sample_y   + (has_ids ? 0 : channel_y  *stride_channel_y);
     dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
 
+    if constexpr (has_ids) {
+        constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
+        const int64_t col_offset = col_base;
+        y   += col_offset * stride_col_y * y_stride_scale;
+        dst += col_offset * stride_col_dst;
+        ids += col_offset * stride_row_id;
+    }
+
     const float2 * y2 = (const float2 *) y;
 
     extern __shared__ char data_mmv[];
@@ -61,12 +80,17 @@ static __global__ void mul_mat_f(
 
         for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
             const int j = j0 + threadIdx.y;
-            const int32_t * __restrict__ id_row = ids + j*stride_row_id;
 
             if (threadIdx.x == 0) {
                 slot_map[j] = -1;
             }
 
+            if (col_base + j >= ncols_dst_total) {
+                continue;
+            }
+
+            const int32_t * __restrict__ id_row = ids + j*stride_row_id;
+
             for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
                 int match = id_row[k*stride_col_id] == expert_idx;
 
@@ -108,7 +132,8 @@ static __global__ void mul_mat_f(
                     if constexpr (!has_ids) {
                         tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
                     } else {
-                        tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
+                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
+                        tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
                     }
                 }
             } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
@@ -120,7 +145,8 @@ static __global__ void mul_mat_f(
                         const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
                         tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                     } else {
-                        float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
+                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
+                        float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
                         tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                     }
                 }
@@ -183,14 +209,14 @@ static __global__ void mul_mat_f(
             dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
         } else {
             const int slot = (j < cols_per_block) ? slot_map[j] : -1;
-            if (slot >= 0) {
+            if (slot >= 0 && (col_base + j) < ncols_dst_total) {
                 dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
             }
         }
     }
 #else
     GGML_UNUSED_VARS(x, y, ids, dst,
-        ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
         stride_col_id, stride_row_id,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
@@ -201,20 +227,23 @@ static __global__ void mul_mat_f(
 template<typename T, int cols_per_block, int nwarps>
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols_x, const int64_t nchannels_dst,
+        const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
         const int64_t stride_col_id, const int64_t stride_row_id,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
     if (ids) {
-        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-             (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+        const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
+        dim3 block_nums_ids = block_nums;
+        block_nums_ids.y *= col_tiles;
+        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
+             (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
              stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
              sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     } else {
         mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+            (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
              stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
              sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     }
@@ -223,7 +252,8 @@ static inline void mul_mat_f_switch_ids(
 template <typename T, int cols_per_block>
 void mul_mat_f_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
         const int64_t stride_col_id, const int64_t stride_row_id,
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
@@ -268,49 +298,49 @@ void mul_mat_f_cuda(
     switch (nwarps_best) {
         case 1: {
             mul_mat_f_switch_ids<T, cols_per_block, 1>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 2: {
             mul_mat_f_switch_ids<T, cols_per_block, 2>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 3: {
             mul_mat_f_switch_ids<T, cols_per_block, 3>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 4: {
             mul_mat_f_switch_ids<T, cols_per_block, 4>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 5: {
             mul_mat_f_switch_ids<T, cols_per_block, 5>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 6: {
             mul_mat_f_switch_ids<T, cols_per_block, 6>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 7: {
             mul_mat_f_switch_ids<T, cols_per_block, 7>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 8: {
             mul_mat_f_switch_ids<T, cols_per_block, 8>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
@@ -332,84 +362,89 @@ static void mul_mat_f_switch_cols_per_block(
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
-    switch (ncols_dst) {
+
+    const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
+
+    GGML_ASSERT(ids || ncols_dst <= 16);
+
+    switch (ncols_case) {
         case  1: {
-            mul_mat_f_cuda<T,  1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T,  1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case  2: {
-            mul_mat_f_cuda<T,  2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T,  2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case  3: {
-            mul_mat_f_cuda<T,  3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T,  3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case  4: {
-            mul_mat_f_cuda<T,  4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T,  4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case  5: {
-            mul_mat_f_cuda<T,  5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T,  5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y,  stride_sample_dst, stream);
         } break;
         case  6: {
-            mul_mat_f_cuda<T,  6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T,  6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case  7: {
-            mul_mat_f_cuda<T,  7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T,  7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case  8: {
-            mul_mat_f_cuda<T,  8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T,  8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case  9: {
-            mul_mat_f_cuda<T,  9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T,  9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 10: {
-            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 11: {
-            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 12: {
-            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 13: {
-            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 14: {
-            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 15: {
-            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 16: {
-            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
@@ -422,7 +457,7 @@ static void mul_mat_f_switch_cols_per_block(
 #define DECL_MMF_CASE_HELPER(T, ncols_dst) \
     template void mul_mat_f_cuda<T, ncols_dst>( \
         const T * x, const float * y, const int32_t * ids, float * dst, \
-        const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
+        const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
         const int64_t stride_col_id, const int64_t stride_row_id, \
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\