#include <vector>
+// To reduce shared memory use, pack "it" and "iex_used" into a single uint32_t, using 22 and 10 bits respectively.
+struct mmq_ids_helper_store {
+ uint32_t data;
+
+ __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
+ data = (it & 0x003FFFFF) | (iex_used << 22);
+ }
+
+ __device__ uint32_t it() const {
+ return data & 0x003FFFFF;
+ }
+
+ __device__ uint32_t iex_used() const {
+ return data >> 22;
+ }
+};
+static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
+
+// Helper function for mul_mat_id, converts ids to a more convenient format.
+// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
+// ids_dst describes the same mapping but for the dst tensor.
+// The lower and upper bounds for the i-th expert in the compact src1 tensor are stored in expert_bounds[i] and expert_bounds[i+1].
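+// ids has ne[0] == n_expert_used and ne[1] == n_tokens; one single-warp block is launched per expert (blockIdx.x == expert).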
+template <int n_expert_used_template>
+__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mmq_ids_helper(
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
+ const int expert = blockIdx.x;
+
+ extern __shared__ char data_mmq_ids_helper[];
+ mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
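+ // One packed (it, iex_used) entry per token that uses this expert; sized for the worst case of n_tokens entries.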
+
+ int nex_prev = 0; // Number of columns for experts with a lower index.
+ int it_compact = 0; // Running index for the compact slice of this expert.
+
+ if constexpr (n_expert_used_template == 0) {
+ // Generic implementation:
+ for (int it = 0; it < n_tokens; ++it) {
+ int iex_used = -1; // The index at which the expert is used, if any.
+ for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
+ const int expert_used = ids[it*si1 + iex];
+ nex_prev += expert_used < expert;
+ if (expert_used == expert) {
+ iex_used = iex;
+ }
+ }
+
+ if (iex_used != -1) {
+ store[it_compact] = mmq_ids_helper_store(it, iex_used);
+ }
+
+ if (warp_reduce_any<warp_size>(iex_used != -1)) {
+ it_compact++;
+ }
+ }
+ } else {
+ // Implementation optimized for specific numbers of experts used:
+ static_assert(n_expert_used_template == 6 || warp_size % n_expert_used_template == 0, "bad n_expert_used_template");
+ constexpr int neu_padded = n_expert_used_template == 6 ? 8 : n_expert_used_template; // Padded to the next higher power of 2.
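+ // Each group of neu_padded consecutive threads handles one token, so a warp covers warp_size/neu_padded tokens per iteration.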
+ for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
+ const int it = it0 + threadIdx.x / neu_padded;
+
+ const int iex = threadIdx.x % neu_padded; // Expert slot within the token that this thread checks.
+ const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
+ ids[it*si1 + iex] : INT_MAX;
+ const int iex_used = expert_used == expert ? iex : -1;
+ nex_prev += expert_used < expert;
+
+ // Whether the threads at this token position have used the expert:
+ const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
+
+ // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
+ int it_compact_add_lower = 0;
+#pragma unroll
+ for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
+ const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
+ if (threadIdx.x >= offset) {
+ it_compact_add_lower += tmp;
+ }
+ }
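+ // it_compact_add_lower now holds the number of tokens at lower positions in this iteration that use the expert (exclusive prefix sum).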
+
+ if (iex_used != -1) {
+ store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
+ }
+
+ // The thread with the highest index in the warp always has the sum over the whole warp; use it to increment it_compact for all threads:
+ it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
+ }
+ }
+ nex_prev = warp_reduce_sum<warp_size>(nex_prev);
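+ // nex_prev now holds the number of compact rows belonging to experts with a lower index, i.e. the start offset of this expert's slice.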
+
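+ // Write the permutation entries for this expert's slice of the compact src1/dst tensors: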
+ for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
+ const mmq_ids_helper_store store_it = store[itc];
+ const int it = store_it.it();
+ const int iex_used = store_it.iex_used();
+ ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
+ ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
+ }
+
+ if (threadIdx.x != 0) {
+ return;
+ }
+
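+ // Thread 0 writes the lower bound for this expert; the last block also writes the total count as the final bound.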
+ expert_bounds[expert] = nex_prev;
+
+ if (expert < gridDim.x - 1) {
+ return;
+ }
+
+ expert_bounds[gridDim.x] = nex_prev + it_compact;
+}
+
+template <int n_expert_used_template>
+static void launch_mmq_ids_helper(
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+ GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
+ GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
+
+ const int id = ggml_cuda_get_device();
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+ CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
+
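+ // One single-warp block per expert; shared memory holds one mmq_ids_helper_store entry per token.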
+ const dim3 num_blocks(n_experts, 1, 1);
+ const dim3 block_size(warp_size, 1, 1);
+ const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
+ GGML_ASSERT(nbytes_shared <= smpbo);
+ mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
+ (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+}
+
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
case GGML_TYPE_Q4_0:
ne00, ne01, ne1, s01, ne11, s1,
ne02, ne12, s02, s12, s2,
ne03, ne13, s03, s13, s3,
- use_stream_k};
+ use_stream_k, ne1};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
return;
}
const int64_t n_expert_used = ids->ne[0];
const int64_t ne_get_rows = ne12 * n_expert_used;
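+ // For mul_mat_id the dst tensor has one column per used expert for each token, so ne1 must match n_expert_used.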
+ GGML_ASSERT(ne1 == n_expert_used);
- std::vector<char> ids_host(ggml_nbytes(ids));
- std::vector<int32_t> ids_src1_host;
- ids_src1_host.reserve(ne_get_rows);
- std::vector<int32_t> ids_dst_host;
- ids_dst_host.reserve(ne_get_rows);
- std::vector<int32_t> tokens_per_expert_host(ne02);
- std::vector<int32_t> expert_bounds_host(ne02 + 1);
- ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
-
- CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
- CUDA_CHECK(cudaStreamSynchronize(stream));
-
- for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
- for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
- for (int64_t iex = 0; iex < n_expert_used; ++iex) {
- const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
- assert(expert_to_use >= 0 && expert_to_use < ne02);
- if (expert_to_use == i02) {
- ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
- ids_dst_host.push_back(i12*ne1 + iex);
- tokens_per_expert_host[i02]++;
- break;
- }
- }
- }
- }
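+ // Device buffers for the compacted src1/dst indices and the per-expert bounds, filled by mmq_ids_helper: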
+ ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
+ ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
+ ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
- int32_t cumsum = 0;
- for (int64_t i = 0; i < ne02; ++i) {
- expert_bounds_host[i] = cumsum;
- cumsum += tokens_per_expert_host[i];
+ {
+ GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
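+ // si1: stride between consecutive tokens in ids (in elements); sis1: stride between consecutive tokens in src1 (in rows, == nb12/nb11).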
+ const int si1 = ids->nb[1] / ggml_element_size(ids);
+ const int sis1 = nb12 / nb11;
+
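+ // Dispatch on n_expert_used so that the inner expert loop can be specialized; a template argument of 0 selects the generic fallback.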
+ switch (n_expert_used) {
+ case 2:
+ launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 4:
+ launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 6:
+ launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 8:
+ launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 16:
+ launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 32:
+ launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ default:
+ launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ }
+ CUDA_CHECK(cudaGetLastError());
}
- expert_bounds_host[ne02] = cumsum;
-
- std::vector<int32_t> ids_buf_host;
- ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
- ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
- ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
- ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
- ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
- CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
- CUDA_CHECK(cudaStreamSynchronize(stream));
-
- const int32_t * ids_src1_dev = ids_buf_dev.ptr;
- const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size();
- const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;
- quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
+ quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
CUDA_CHECK(cudaGetLastError());
}
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
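+ // ne12 is passed as ncols_max: a single expert can be assigned at most ne12 columns (one per token), which bounds the number of tiles in the x direction.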
const mmq_args args = {
- src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
+ src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
ne02, ne02, s02, s12, s2,
ne03, ne13, s03, s13, s3,
- use_stream_k};
+ use_stream_k, ne12};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
}
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
1, 1, 0, 0, 0,
1, 1, 0, 0, 0,
- use_stream_k};
+ use_stream_k, src1_ncols};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
- const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+ const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ const int ncols_max) {
// Skip unused template specializations for faster compilation:
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int mmq_y = get_mmq_y_device();
- const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
// Initialize the ids for writing back data with just the index.
static __global__ void mul_mat_q_stream_k_fixup(
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
- const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
+ const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
+ const int ncols_max) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
- const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
const int bidx0 = blockIdx.x;
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
- bool use_stream_k;
+ bool use_stream_k; int64_t ncols_max;
};
template<ggml_type type>
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
- const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
+ const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
const int ntzw = args.nchannels_y * args.nsamples_y;
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
}
return;
}
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
if (!fixup_needed) {
return;
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+ args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
if (!fixup_needed) {
return;
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+ args.ncols_max);
}
}
continue;
}
- const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
+ const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
if (ntiles_x < ntiles_x_best) {
mmq_x_best = mmq_x;