#include <cstdint>
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+#define MMQ_NWARPS 8
typedef void (*load_tiles_mmq_t)(
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
typedef void (*vec_dot_mmq_t)(
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0);
-typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
struct block_q8_1_mmq {
half2 ds[4];
// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
+static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
- return mmq_x >= 32 ? 128 : 64;
-}
+ return 128;
#else
#if __CUDA_ARCH__ >= CC_VOLTA
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
- return mmq_x >= 32 ? 128 : 64;
-}
+ return 128;
#else
-static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
return 64;
-}
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
}
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_dp4a(
+ const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+ const int j = j0 + threadIdx.y;
- if (j >= ne1) {
+ if (j > j_max) {
return;
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+ const int i = i0 + threadIdx.x;
- if (need_check && i >= ne0) {
+ if (need_check && i > i_max) {
continue;
}
- dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+ dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
}
}
}
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_mma(
+ const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
typedef mma_int_C_I16J8 mma_C;
const int i0 = threadIdx.y*mma_C::I;
for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
- const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
+ const int j = j0 + mma_C::get_j(l);
- if (j >= ne1) {
+ if (j > j_max) {
continue;
}
- const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
+ const int i = i0 + mma_C::get_i(l);
- if (need_check && i >= ne0) {
+ if (need_check && i > i_max) {
continue;
}
- dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+ dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
}
}
}
return false;
}
-template <ggml_type type, int mmq_x, int nwarps, bool need_check>
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
- __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#else
-#if __CUDA_ARCH__ >= CC_VOLTA
- __launch_bounds__(WARP_SIZE*nwarps, 1)
-#else
- __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // __CUDA_ARCH__ >= CC_VOLTA
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-static __global__ void mul_mat_q(
- const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
- const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
-
- // Skip unused template specializations for faster compilation:
- if (mmq_x > get_mmq_x_max_device()) {
- NO_DEVICE_CODE;
- return;
- }
+template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+static __device__ void mul_mat_q_process_tile(
+ const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+ const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
+ const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qr = ggml_cuda_type_traits<type>::qr;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
- constexpr int mmq_y = get_mmq_y_device(mmq_x);
+ constexpr int mmq_y = get_mmq_y_device();
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
int * tile_x_sc = (int *) (tile_x_dm + txs.dm);
int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
- const int blocks_per_row_x = ne00 / qk;
- const int blocks_per_warp = WARP_SIZE / qi;
+ constexpr int blocks_per_warp = WARP_SIZE / qi;
- const int & ne1 = ne11;
-
- const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
+ float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
- const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+ const int tile_x_max_i = ne01 - it*mmq_y - 1;
+ const int tile_y_max_j = ne11 - jt*mmq_x - 1;
- float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+ const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
- for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
+ for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
- load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
+ load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
#pragma unroll
for (int kr = 0; kr < qr; ++kr) {
}
}
- write_back(sum, dst, ne0, ne1);
+ if (fixup) {
+ write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+ } else {
+ write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
+ }
+}
+
+
+// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+ __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+ __launch_bounds__(WARP_SIZE*nwarps, 1)
+#else
+ __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+static __global__ void mul_mat_q(
+ const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+ const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
+
+ // Skip unused template specializations for faster compilation:
+ if (mmq_x > get_mmq_x_max_device()) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int qi = ggml_cuda_type_traits<type>::qi;
+ constexpr int mmq_y = get_mmq_y_device();
+
+ // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+ {
+ constexpr bool fixup = false;
+ mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+ blockIdx.x, blockIdx.y, 0, ne00/qk);
+ return;
+ }
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+
+ const int64_t blocks_per_ne00 = ne00 / qk;
+ constexpr int blocks_per_warp = WARP_SIZE / qi;
+
+ const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
+ const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
+
+ // kbc == k block continuous, current index in continuous ijk space.
+ int64_t kbc = GGML_PAD((int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+ const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+
+ // kb0 == k index when doing the matrix multiplication for an output tile.
+ int kb0_start = kbc % blocks_per_ne00;
+ int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
+ while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
+ const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile.
+ const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
+
+ constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+ mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+ it, jt, kb0_start, kb0_stop);
+
+ kbc += blocks_per_ne00;
+ kbc -= kbc % blocks_per_ne00;
+
+ kb0_start = 0;
+ kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
+ }
+
+ if (kbc >= kbc_stop) {
+ return;
+ }
+
+ const int jt = kbc / (blocks_per_ne00*nty);
+ const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+ constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
+ mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+ it, jt, kb0_start, kb0_stop);
+}
+
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+static __global__ void mul_mat_q_stream_k_fixup(
+ float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
+
+ constexpr int mmq_y = get_mmq_y_device();
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int qi = ggml_cuda_type_traits<type>::qi;
+ constexpr int blocks_per_warp = WARP_SIZE / qi;
+ const int64_t blocks_per_ne00 = ne00 / qk;
+
+ float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+ const int ntx = (ne11 + mmq_x - 1) / mmq_x;
+ const int nty = (ne01 + mmq_y - 1) / mmq_y;
+
+ bool any_fixup = false;
+
+ const int bidx_start = (blockIdx.y*nty + blockIdx.x) * block_num_mmq / (gridDim.y*gridDim.x);
+ const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
+
+ for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
+ const int64_t kbc = GGML_PAD((int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+ const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+
+ // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
+ if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
+ continue;
+ }
+
+ const int jt = kbc_stop / (blocks_per_ne00*nty);
+ const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+ // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
+ if (it != blockIdx.x || jt != blockIdx.y) {
+ continue;
+ }
+
+ any_fixup = true;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+ }
+ }
+ }
+
+ if (!any_fixup) {
+ return;
+ }
+
+ dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
+
+ const int i_max = ne01 - blockIdx.x*mmq_y - 1;
+ const int j_max = ne11 - blockIdx.y*mmq_x - 1;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j > j_max) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ if (need_check && i > i_max) {
+ continue;
+ }
+
+ dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+ }
+ }
}
struct mmq_args {
int64_t ne0;
};
-constexpr int mmq_get_nwarps(int mmq_x) {
- return mmq_x >= 32 ? 8 : 4;
-}
-
static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
- const int nwarps = mmq_get_nwarps(mmq_x);
const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
- return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
+ return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
}
-template <ggml_type type, int mmq_x, int nwarps>
-static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
+template <ggml_type type, int mmq_x>
+static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
- const int mmq_y = get_mmq_y_host(cc, mmq_x);
+ const int nsm = ggml_cuda_info().devices[id].nsm;
+ const int mmq_y = get_mmq_y_host(cc);
- const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
- const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shmem_limit_raised[id]) {
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+ CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+ CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
shmem_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+ const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
+ const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
+ const dim3 block_nums_xy_tiling(nty, ntx, 1);
+
+ const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
+ if (!use_stream_k) {
+ if (args.ne01 % mmq_y == 0) {
+ constexpr bool need_check = false;
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+ (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+ } else {
+ constexpr bool need_check = true;
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+ (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+ }
+ return;
+ }
+
+ const dim3 block_nums_mmq(nsm, 1, 1);
+
+ ggml_cuda_pool & pool = ctx.pool();
+ ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
+
if (args.ne01 % mmq_y == 0) {
- const bool need_check = false;
- mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
- (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+ constexpr bool need_check = false;
+
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+ (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+ mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+ (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
} else {
- const bool need_check = true;
- mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
- (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+ constexpr bool need_check = true;
+
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+ (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+ mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+ (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
}
}
template <ggml_type type>
-void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
+void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
const int cc = ggml_cuda_info().devices[id].cc;
const int smpbo = ggml_cuda_info().devices[id].smpbo;
const int mmq_x_max = get_mmq_x_max_host(cc);
- const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
+ const int mmq_y = get_mmq_y_host(cc);
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
+ const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
int mmq_x_best = 0;
- int nwaves_best = INT_MAX;
+ int nparts_best = INT_MAX;
+
+ for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
+ const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
+ const int nwaves_xy_tiling = ntiles_x*block_num_y;
- for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) {
- const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
- const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
+ const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
- if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
+ if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
mmq_x_best = mmq_x;
- nwaves_best = nwaves;
+ nparts_best = nparts;
}
}
switch (mmq_x_best) {
case 8:
- launch_mul_mat_q<type, 8, mmq_get_nwarps( 8)>(args, stream);
+ launch_mul_mat_q<type, 8>(ctx, args, stream);
break;
case 16:
- launch_mul_mat_q<type, 16, mmq_get_nwarps( 16)>(args, stream);
+ launch_mul_mat_q<type, 16>(ctx, args, stream);
break;
case 24:
- launch_mul_mat_q<type, 24, mmq_get_nwarps( 24)>(args, stream);
+ launch_mul_mat_q<type, 24>(ctx, args, stream);
break;
case 32:
- launch_mul_mat_q<type, 32, mmq_get_nwarps( 32)>(args, stream);
+ launch_mul_mat_q<type, 32>(ctx, args, stream);
break;
case 40:
- launch_mul_mat_q<type, 40, mmq_get_nwarps( 40)>(args, stream);
+ launch_mul_mat_q<type, 40>(ctx, args, stream);
break;
case 48:
- launch_mul_mat_q<type, 48, mmq_get_nwarps( 48)>(args, stream);
+ launch_mul_mat_q<type, 48>(ctx, args, stream);
break;
case 56:
- launch_mul_mat_q<type, 56, mmq_get_nwarps( 56)>(args, stream);
+ launch_mul_mat_q<type, 56>(ctx, args, stream);
break;
case 64:
- launch_mul_mat_q<type, 64, mmq_get_nwarps( 64)>(args, stream);
+ launch_mul_mat_q<type, 64>(ctx, args, stream);
break;
case 72:
- launch_mul_mat_q<type, 72, mmq_get_nwarps( 72)>(args, stream);
+ launch_mul_mat_q<type, 72>(ctx, args, stream);
break;
case 80:
- launch_mul_mat_q<type, 80, mmq_get_nwarps( 80)>(args, stream);
+ launch_mul_mat_q<type, 80>(ctx, args, stream);
break;
case 88:
- launch_mul_mat_q<type, 88, mmq_get_nwarps( 88)>(args, stream);
+ launch_mul_mat_q<type, 88>(ctx, args, stream);
break;
case 96:
- launch_mul_mat_q<type, 96, mmq_get_nwarps( 96)>(args, stream);
+ launch_mul_mat_q<type, 96>(ctx, args, stream);
break;
case 104:
- launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream);
+ launch_mul_mat_q<type, 104>(ctx, args, stream);
break;
case 112:
- launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream);
+ launch_mul_mat_q<type, 112>(ctx, args, stream);
break;
case 120:
- launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream);
+ launch_mul_mat_q<type, 120>(ctx, args, stream);
break;
case 128:
- launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream);
+ launch_mul_mat_q<type, 128>(ctx, args, stream);
break;
default:
fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
}
#define DECL_MMQ_CASE(type) \
- template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream) \
+ template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);