tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
}
-
template <ggml_type type, int mmq_x, bool need_check>
-static __global__ void mul_mat_q_stream_k_fixup(
- const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
- const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
- const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
- const int ncols_max) {
+static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
+ const int32_t * expert_bounds,
+ float * __restrict__ dst,
+ const float * __restrict__ tmp_last_tile,
+ const int ncols_x,
+ const int nrows_x,
+ const int ncols_dst,
+ const size_t stride_col_dst,
+ const int nchannels_y,
+ const size_t stride_channel_dst,
+ const int nsamples_y,
+ const size_t stride_sample_dst,
+ const int ncols_max) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int ITER_K = get_iter_k(type);