#define MMF_ROWS_PER_BLOCK 32
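+// Compact per-expert routing data for mul_mat_id: src1/dst index lists sorted by expert plus the
+// per-expert bounds (as built by ggml_cuda_launch_mm_ids_helper), consumed by the mul_mat_f_ids path.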
+struct mmf_ids_data {
+ const int32_t * ids_src_compact = nullptr;
+ const int32_t * ids_dst_compact = nullptr;
+ const int32_t * expert_bounds_dev = nullptr;
+ int n_experts = 0;
+ int sis1 = 0;
+};
+
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
+
+// This kernel handles the larger batch sizes of mul_mat_id.
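+// Grid mapping: blockIdx.x selects a tile of rows_per_block rows of x, blockIdx.y selects the expert,
+// and blockIdx.z selects a tile of cols_per_block columns within that expert's compact src1 slice.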
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f_ids(
+ const T * __restrict__ x, const float * __restrict__ y,
+ const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
+ const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
+ const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ const uint3 sis1_fd, const uint3 nch_fd) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+ typedef tile<16, 8, T> tile_A;
+ typedef tile< 8, 8, T> tile_B;
+ typedef tile<16, 8, float> tile_C;
+
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ constexpr int tile_k_padded = warp_size + 4;
+ constexpr int ntA = rows_per_block / tile_A::I;
+ constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+ const int row0 = blockIdx.x * rows_per_block;
+
+ const int expert_idx = blockIdx.y;
+ const int expert_start = expert_bounds[expert_idx];
+ const int expert_end = expert_bounds[expert_idx + 1];
+ const int ncols_expert = expert_end - expert_start;
+
+ const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
+ const int tile_idx = blockIdx.z;
+ if (tile_idx >= tiles_for_expert) {
+ return;
+ }
+
+ const int col_base = tile_idx * cols_per_block;
+
+ GGML_UNUSED(channel_ratio);
+
+ const int channel_x = expert_idx;
+ const int sample_dst = 0;
+ const int sample_x = sample_dst / sample_ratio;
+ const int sample_y = sample_dst;
+
+ x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row;
+ y += int64_t(sample_y) *stride_sample_y;
+ dst += int64_t(sample_dst)*stride_sample_dst;
+
+ const int32_t * ids_src_expert = ids_src_compact + expert_start;
+ const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
+
+ extern __shared__ char data_mmv[];
+ char * compute_base = data_mmv;
+
+ tile_C C[ntA][ntB];
+
+ T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
+
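+ // Loop over the K dimension (ncols of x) in chunks of nwarps*warp_size, staging a tile of x in
+ // shared memory and multiplying it against the gathered columns of y for this expert.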
+ for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+ tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+ for (int i = 0; i < tile_A::I; ++i) {
+ tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
+ }
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+ load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+ }
+ }
+
+ if constexpr (std::is_same_v<T, float>) {
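+ // vals_buf is double-buffered: the gather for tile itB+1 is issued before the mma loop for tile itB,
+ // so the global loads of y can overlap with the tensor core math.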
+ float vals_buf[2][tile_B::I];
+ auto gather_tile = [&](int tile_idx_local, float *vals) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const int j = j0 + tile_idx_local*tile_B::I;
+ const int global_j = col_base + j;
+ float val = 0.0f;
+ if (j < cols_per_block && global_j < ncols_expert) {
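+ // ids_src_expert holds flattened src1 column indices; fastdiv by sis1 splits them into (token, channel).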
+ const int src_entry = ids_src_expert[global_j];
+ const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
+ const int token = (int) qrm.x;
+ const int channel = (int) qrm.y;
+ if (token < ncols_dst_total) {
+ val = y[channel*stride_channel_y + token*stride_col_y + col];
+ }
+ }
+ vals[j0] = val;
+ }
+ };
+
+ gather_tile(0, vals_buf[0]);
+
+ int curr_buf = 0;
+ int next_buf = 1;
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
+ }
+
+ if (itB + 1 < ntB) {
+ gather_tile(itB + 1, vals_buf[next_buf]);
+ }
+
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+ tile_B B;
+ load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+ mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+ }
+ }
+
+ if (itB + 1 < ntB) {
+ curr_buf ^= 1;
+ next_buf ^= 1;
+ }
+ }
+ } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+ float2 vals_buf[2][tile_B::I];
+ auto gather_tile = [&](int tile_idx_local, float2 *vals) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const int j = j0 + tile_idx_local*tile_B::I;
+ const int global_j = col_base + j;
+ float2 tmp = make_float2(0.0f, 0.0f);
+ if (j < cols_per_block && global_j < ncols_expert) {
+ const int src_entry = ids_src_expert[global_j];
+ const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
+ const int token = (int) qrm.x;
+ const int channel = (int) qrm.y;
+ if (token < ncols_dst_total) {
+ tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
+ }
+ }
+ vals[j0] = tmp;
+ }
+ };
+
+ if (ntB > 0) {
+ gather_tile(0, vals_buf[0]);
+ }
+
+ int curr_buf = 0;
+ int next_buf = 1;
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const float2 tmp = vals_buf[curr_buf][j0];
+ tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+ }
+
+ if (itB + 1 < ntB) {
+ gather_tile(itB + 1, vals_buf[next_buf]);
+ }
+
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+ tile_B B;
+ load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+ mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+ }
+ }
+
+ if (itB + 1 < ntB) {
+ curr_buf ^= 1;
+ next_buf ^= 1;
+ }
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "unsupported type");
+ }
+ }
+
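+ // Combine the per-warp partial sums: each warp writes its C tiles to shared memory, the rows are then
+ // summed across warps and scattered to dst via ids_dst_expert.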
+ float * buf_iw = (float *) compute_base;
+ constexpr int kiw = nwarps*rows_per_block + 4;
+
+ if (nwarps > 1) {
+ __syncthreads();
+ }
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+ const int j = itB*tile_C::J + tile_C::get_j(l);
+ buf_iw[j*kiw + i] = C[itA][itB].x[l];
+ }
+ }
+ }
+
+ if (nwarps > 1) {
+ __syncthreads();
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+ return;
+ }
+
+ float sum = 0.0f;
+ static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+ for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+ const int i = i0 + threadIdx.x;
+
+ sum += buf_iw[j*kiw + i];
+ }
+
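+ // Map the compact column index back to its (token, slot) position in dst.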
+ const int global_j = col_base + j;
+ if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
+ const int dst_entry = ids_dst_expert[global_j];
+ const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
+ const int token = (int) qrm.x;
+ if (token < ncols_dst_total) {
+ const int slot = (int) qrm.y;
+ dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
+ }
+ }
+ }
+#else
+ GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
+ ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
+ NO_DEVICE_CODE;
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
- const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
- if (ids) {
+ const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
+ const mmf_ids_data * ids_data) {
+ const bool has_ids_data = ids_data && ids_data->ids_src_compact;
+
+ // Use the compact-ids kernel only for larger tiles; for small ncols_dst (<= 16)
+ // we prefer the normal mul_mat_f path with has_ids=true.
+ if (has_ids_data && ncols_dst > 16) {
+ const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
+ if (max_tiles == 0) {
+ return;
+ }
+ dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
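+ // max_tiles is an upper bound over all experts; z-blocks beyond an expert's actual column count
+ // exit early inside the kernel.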
+
+ const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
+ const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
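+ // sis1_fd and nch_fd are fastdiv constants the kernel uses to split compact src1/dst indices
+ // into (token, channel) and (token, slot) pairs.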
+
+ mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
+ (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
+ ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
+ sis1_fd, nch_fd);
+ } else if (ids) {
const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
dim3 block_nums_ids = block_nums;
block_nums_ids.y *= col_tiles;
+
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
- (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
- cudaStream_t stream) {
+ cudaStream_t stream, const mmf_ids_data * ids_data) {
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
- const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
+ const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1);
mul_mat_f_switch_ids<T, cols_per_block, 1>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
} break;
case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
} break;
case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
} break;
case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
} break;
case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
} break;
case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
} break;
case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
} break;
case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
} break;
default: {
GGML_ABORT("fatal error");
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
- cudaStream_t stream) {
+ cudaStream_t stream, const mmf_ids_data * ids_data) {
const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
default: {
GGML_ABORT("fatal error");
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
- cudaStream_t stream);
+ cudaStream_t stream, const mmf_ids_data * ids_data);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
--- /dev/null
+#include "common.cuh"
+#include "mmid.cuh"
+
+// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
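+// This limits n_tokens to below 2^22 and n_expert_used to below 2^10; both are asserted in launch_mm_ids_helper.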
+struct mm_ids_helper_store {
+ uint32_t data;
+
+ __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
+ data = (it & 0x003FFFFF) | (iex_used << 22);
+ }
+
+ __device__ uint32_t it() const {
+ return data & 0x003FFFFF;
+ }
+
+ __device__ uint32_t iex_used() const {
+ return data >> 22;
+ }
+};
+static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
+
+// Helper function for mul_mat_id, converts ids to a more convenient format.
+// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
+// ids_dst describes the same mapping but for the dst tensor.
+// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
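+// Illustrative example: with n_tokens = 3, n_expert_used = 2 and ids = {{0, 2}, {1, 0}, {2, 1}},
+// expert 0 is used by tokens 0 and 1, expert 1 by tokens 1 and 2, and expert 2 by tokens 0 and 2,
+// so expert_bounds = {0, 2, 4, 6} and ids_dst = {0, 3, 2, 5, 1, 4} (it*n_expert_used + iex_used,
+// grouped by expert).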
+template <int n_expert_used_template>
+__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mm_ids_helper(
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
+ const int expert = blockIdx.x;
+
+ extern __shared__ char data_mm_ids_helper[];
+ mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
+
+ int nex_prev = 0; // Number of columns for experts with a lower index.
+ int it_compact = 0; // Running index for the compact slice of this expert.
+
+ if constexpr (n_expert_used_template == 0) {
+ // Generic implementation:
+ for (int it = 0; it < n_tokens; ++it) {
+ int iex_used = -1; // The index at which the expert is used, if any.
+ for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
+ const int expert_used = ids[it*si1 + iex];
+ nex_prev += expert_used < expert;
+ if (expert_used == expert) {
+ iex_used = iex;
+ }
+ }
+
+ if (iex_used != -1) {
+ store[it_compact] = mm_ids_helper_store(it, iex_used);
+ }
+
+ if (warp_reduce_any<warp_size>(iex_used != -1)) {
+ it_compact++;
+ }
+ }
+ } else {
+ // Implementation optimized for specific numbers of experts used:
+ static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
+ const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
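+ // With n_expert_used == 6 each token occupies 8 lanes; the 2 surplus lanes load no id and are
+ // assigned INT_MAX below so they never match an expert.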
+ for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
+ const int it = it0 + threadIdx.x / neu_padded;
+
+ const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
+ const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
+ ids[it*si1 + iex] : INT_MAX;
+ const int iex_used = expert_used == expert ? iex : -1;
+ nex_prev += expert_used < expert;
+
+ // Whether the threads at this token position have used the expert:
+ const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
+
+ // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
+ int it_compact_add_lower = 0;
+#pragma unroll
+ for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
+ const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
+ if (threadIdx.x >= static_cast<unsigned int>(offset)) {
+ it_compact_add_lower += tmp;
+ }
+ }
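+ // it_compact_add_lower is now the exclusive prefix sum of it_compact_add_self over the lower
+ // token groups in this warp.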
+
+ if (iex_used != -1) {
+ store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
+ }
+
+ // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
+ it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
+ }
+ }
+ nex_prev = warp_reduce_sum<warp_size>(nex_prev);
+
+ for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
+ const mm_ids_helper_store store_it = store[itc];
+ const int it = store_it.it();
+ const int iex_used = store_it.iex_used();
+ ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
+ ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
+ }
+
+ if (threadIdx.x != 0) {
+ return;
+ }
+
+ expert_bounds[expert] = nex_prev;
+
+ if (expert < static_cast<int>(gridDim.x) - 1) {
+ return;
+ }
+
+ expert_bounds[gridDim.x] = nex_prev + it_compact;
+}
+
+template <int n_expert_used_template>
+static void launch_mm_ids_helper(
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+ GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store");
+ GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");
+
+ const int id = ggml_cuda_get_device();
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+ CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);
+
+ const dim3 num_blocks(n_experts, 1, 1);
+ const dim3 block_size(warp_size, 1, 1);
+ const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
+ GGML_ASSERT(nbytes_shared <= smpbo);
+ mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
+ (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+}
+
+void ggml_cuda_launch_mm_ids_helper(
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+ switch (n_expert_used) {
+ case 2:
+ launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ case 4:
+ launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ case 6:
+ launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ case 8:
+ launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ case 16:
+ launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ case 32:
+ launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ default:
+ launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ }
+}
#include "mmq.cuh"
#include "quantize.cuh"
-
-#include <vector>
-
-// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
-struct mmq_ids_helper_store {
- uint32_t data;
-
- __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
- data = (it & 0x003FFFFF) | (iex_used << 22);
- }
-
- __device__ uint32_t it() const {
- return data & 0x003FFFFF;
- }
-
- __device__ uint32_t iex_used() const {
- return data >> 22;
- }
-};
-static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
-
-// Helper function for mul_mat_id, converts ids to a more convenient format.
-// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
-// ids_dst describes the same mapping but for the dst tensor.
-// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
-template <int n_expert_used_template>
-__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
-static __global__ void mmq_ids_helper(
- const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
- const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
- constexpr int warp_size = ggml_cuda_get_physical_warp_size();
- const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
- const int expert = blockIdx.x;
-
- extern __shared__ char data_mmq_ids_helper[];
- mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
-
- int nex_prev = 0; // Number of columns for experts with a lower index.
- int it_compact = 0; // Running index for the compact slice of this expert.
-
- if constexpr (n_expert_used_template == 0) {
- // Generic implementation:
- for (int it = 0; it < n_tokens; ++it) {
- int iex_used = -1; // The index at which the expert is used, if any.
- for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
- const int expert_used = ids[it*si1 + iex];
- nex_prev += expert_used < expert;
- if (expert_used == expert) {
- iex_used = iex;
- }
- }
-
- if (iex_used != -1) {
- store[it_compact] = mmq_ids_helper_store(it, iex_used);
- }
-
- if (warp_reduce_any<warp_size>(iex_used != -1)) {
- it_compact++;
- }
- }
- } else {
- // Implementation optimized for specific numbers of experts used:
- static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
- const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
- for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
- const int it = it0 + threadIdx.x / neu_padded;
-
- const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
- const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
- ids[it*si1 + iex] : INT_MAX;
- const int iex_used = expert_used == expert ? iex : -1;
- nex_prev += expert_used < expert;
-
- // Whether the threads at this token position have used the expert:
- const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
-
- // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
- int it_compact_add_lower = 0;
-#pragma unroll
- for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
- const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
- if (threadIdx.x >= static_cast<unsigned int>(offset)) {
- it_compact_add_lower += tmp;
- }
- }
-
- if (iex_used != -1) {
- store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
- }
-
- // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
- it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
- }
- }
- nex_prev = warp_reduce_sum<warp_size>(nex_prev);
-
- for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
- const mmq_ids_helper_store store_it = store[itc];
- const int it = store_it.it();
- const int iex_used = store_it.iex_used();
- ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
- ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
- }
-
- if (threadIdx.x != 0) {
- return;
- }
-
- expert_bounds[expert] = nex_prev;
-
- if (expert < static_cast<int>(gridDim.x) - 1) {
- return;
- }
-
- expert_bounds[gridDim.x] = nex_prev + it_compact;
-}
-
-template <int n_expert_used_template>
-static void launch_mmq_ids_helper(
- const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
- const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
- GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
- GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
-
- const int id = ggml_cuda_get_device();
- const int warp_size = ggml_cuda_info().devices[id].warp_size;
- const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
- CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
-
- const dim3 num_blocks(n_experts, 1, 1);
- const dim3 block_size(warp_size, 1, 1);
- const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
- GGML_ASSERT(nbytes_shared <= smpbo);
- mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
- (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
-}
+#include "mmid.cuh"
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
const int si1 = ids->nb[1] / ggml_element_size(ids);
const int sis1 = nb12 / nb11;
- switch (n_expert_used) {
- case 2:
- launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
- ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
- break;
- case 4:
- launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
- ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
- break;
- case 6:
- launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
- ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
- break;
- case 8:
- launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
- ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
- break;
- case 16:
- launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
- ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
- break;
- case 32:
- launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
- ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
- break;
- default:
- launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
- ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
- break;
- }
+ ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
CUDA_CHECK(cudaGetLastError());
}