};
static int get_mmq_x_max_host(const int cc) {
- return new_mma_available(cc) ? 128 :
+ return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 :
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
#ifdef GGML_CUDA_FORCE_MMQ
128 : 64;
}
static constexpr __device__ int get_mmq_x_max_device() {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
return 128;
-#else // NEW_MMA_AVAILABLE
+#else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
- return 128;
+ return 64;
#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
return MMQ_DP4A_MAX_BATCH_SIZE;
#endif // GGML_CUDA_FORCE_MMQ
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-
return 64;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
static int get_mmq_y_host(const int cc) {
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}
-#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
-#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
-#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
-#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
-#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
-#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
-#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
+// The K dimension of the tiles has either 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K)
+// or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K) 32 bit elements of quantized data
+// (scales are not included).
+// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
+// The final tile size in K direction is padded to avoid shared memory bank conflicts,
+// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
+#define MMQ_TILE_NE_K 32
+
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
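As a quick numeric check of the padding rule described above (an illustrative sketch, not part of the patch, assuming QI8_0 == 8 as defined in ggml-common.h):

// Illustrative only: dp4a x tiles pad each quant row by 1 int (odd stride), mma tiles by 4 ints (stride % 8 == 4).
static_assert((MMQ_TILE_NE_K + 1) % 2 == 1, "dp4a x-tile quant rows (e.g. Q4_0) have an odd 32-bit stride");
static_assert((2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/8 + 4) % 8 == 4, "mma x-tile rows (MMQ_MMA_TILE_X_K_Q8_0 with QI8_0==8) have stride 4 mod 8");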
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
switch (type) {
}
}
-#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
-#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4)
-#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
}
}
-#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+// block_q8_1_mmq has 128 8-bit ints of quantized data (== 32 32-bit ints) plus 4 32-bit scale values
+#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
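For reference, this works out to 36 32-bit ints per y-tile row: 32 ints of int8 quants plus 4 ints of scale data, matching block_q8_1_mmq (an illustrative calculation assuming QI8_1 == 8, not part of the patch):

static_assert(MMQ_TILE_NE_K + MMQ_TILE_NE_K/8 == 36, "MMQ_TILE_Y_K, assuming QI8_1 == 8");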
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
- return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+ if (amd_mfma_available(cc)) {
+ return mmq_x >= 128 ? 32 : 16;
+ } else if (new_mma_available(cc) && mmq_x >= 48) {
+ return 16;
+ } else {
+ return 8;
+ }
}
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE)
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+ return mmq_x >= 128 ? 32 : 16;
+}
+#elif defined(NEW_MMA_AVAILABLE)
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
return mmq_x >= 48 ? 16 : 8;
}
#else
-static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
+static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
return 8;
}
-#endif // NEW_MMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE
+
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+static int mmq_get_nwarps_host(const int cc) {
+ return amd_mfma_available(cc) ? 8 : 4;
+}
+#else
+static int mmq_get_nwarps_host(const int /*cc*/) {
+ return 8;
+}
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+
+static constexpr __device__ int mmq_get_nwarps_device() {
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(AMD_MFMA_AVAILABLE)
+ return 8;
+#else
+ return 4;
+#endif // AMD_MFMA_AVAILABLE
+#else
+ return 8;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+}
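The host and device variants have to agree for a given GPU so that the launch configuration matches the compile-time nwarps the kernel assumes. A minimal sketch of the intended host-side pairing follows; the helper name and the warp-size parameter are hypothetical, not this file's actual launch code:

// Hypothetical illustration only: the block is warp_size lanes wide and mmq_get_nwarps_host(cc) warps tall.
static dim3 example_mmq_block_dims(const int cc, const int warp_size /* 32 on NVIDIA, 64 on CDNA */) {
    return dim3(warp_size, mmq_get_nwarps_host(cc), 1);
}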
// ------------------------------------------------------------
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + 2*WARP_SIZE);
+ float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kbx = threadIdx.x / QI4_0;
- const int kqsx = threadIdx.x % QI4_0;
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI4_0;
+ const int kqsx = txi % QI4_0;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
if (need_check) {
i = min(i, i_max);
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
const int qs0 = get_int_b2(bxi->qs, kqsx);
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
#else
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
- const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
- int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
-#ifdef NEW_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
- x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+ x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
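The new threads_per_row/nrows indexing is easiest to see with concrete numbers. A sketch assuming MMQ_ITER_K == 256 and QR4_0 == 2 (their values elsewhere in ggml); the *_example constant is purely illustrative:

// Q4_0: one tile row of quantized data spans 256/(4*2) == 32 lanes.
constexpr int threads_per_row_example = 256 / (4*2);
static_assert(threads_per_row_example == 32, "");
// 32-wide NVIDIA warp: nrows == 1, so txi == threadIdx.x and the loop behaves as before.
static_assert(32 / threads_per_row_example == 1, "");
// 64-wide AMD wavefront: nrows == 2, lanes 0..31 load row 2*threadIdx.y and lanes 32..63 load row 2*threadIdx.y + 1.
static_assert(64 / threads_per_row_example == 2, "");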
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
const int * x_qs = (const int *) x;
const half2 * y_ds = (const half2 *) y;
// #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
}
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
- (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
- x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+ (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
+ x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
}
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kbx = threadIdx.x / QI4_1;
- const int kqsx = threadIdx.x % QI4_1;
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI4_1;
+ const int kqsx = txi % QI4_1;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
if (need_check) {
i = min(i, i_max);
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
const int qs0 = get_int_b4(bxi->qs, kqsx);
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
#else
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
- const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
- int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
-#ifdef NEW_MMA_AVAILABLE
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
#else
- x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif // NEW_MMA_AVAILABLE
+ x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
const int * x_qs = (const int *) x;
const half2 * y_ds = (const half2 *) y;
// #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
}
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
- (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
- x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+ (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
+ x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
}
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kbx = threadIdx.x / QI5_0;
- const int kqsx = threadIdx.x % QI5_0;
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI5_0;
+ const int kqsx = txi % QI5_0;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
if (need_check) {
i = min(i, i_max);
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
const int ql = get_int_b2(bxi->qs, kqsx);
- const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+ const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);
int qs0 = (ql >> 0) & 0x0F0F0F0F;
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
#else
- x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
- x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
- const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
- int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
-#ifdef NEW_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
- x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+ x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kbx = threadIdx.x / QI5_1;
- const int kqsx = threadIdx.x % QI5_1;
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI5_1;
+ const int kqsx = txi % QI5_1;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
if (need_check) {
i = min(i, i_max);
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
const int ql = get_int_b4(bxi->qs, kqsx);
- const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+ const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);
int qs0 = (ql >> 0) & 0x0F0F0F0F;
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
#else
- x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
- x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
- const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
- int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
-#ifdef NEW_MMA_AVAILABLE
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
#else
- x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
-#endif // NEW_MMA_AVAILABLE
+ x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_tile + 2*WARP_SIZE);
+ float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kbx = threadIdx.x / QI8_0;
- const int kqsx = threadIdx.x % QI8_0;
+ // MMQ_ITER_K / (4 * QR8_0) == 64 threads per row would be required, but NVIDIA has only 32 threads per warp.
+ constexpr int threads_per_row = 32;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI8_0;
+ const int kqsx = txi % QI8_0;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
if (need_check) {
i = min(i, i_max);
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
-#ifdef NEW_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
#else
- x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
- x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
- const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
+ constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
- int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
-#ifdef NEW_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
- x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+ x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
const int * x_qs = (const int *) x;
const float * y_df = (const float *) y;
// #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
- (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
- x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+ (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
+ x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
}
}
}
}
-template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
+template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE)
+ typedef tile<16, 8, int> tile_A;
+ typedef tile<16, 8, int> tile_B;
+ typedef tile<16, 16, int> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+ const half2 * y_ds = (const half2 *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B;
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ float dB;
+ const int j = j0 + tile_C::get_j(0);
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+ dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ } else {
+ dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_A::I + tile_C::get_i(l);
+ const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
+ }
+ }
+ }
+ }
+#else
typedef tile<16, 8, int> tile_A;
typedef tile< 8, 8, int> tile_B;
typedef tile<16, 8, int> tile_C;
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
- const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
+ const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const half2 * y_ds = (const half2 *) y;
- tile_A A[ntx][WARP_SIZE/QI8_0];
- float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
+ tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
+ float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
const int k0 = k00 + k01;
load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
const int k0 = k00 + k01;
dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
tile_B B;
float dB[tile_C::ne/2];
}
}
}
+#endif // defined(AMD_MFMA_AVAILABLE)
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
const int * x_qs = (const int *) x;
const half2 * y_ds = (const half2 *) y;
// #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
- (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
- x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+ (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
}
}
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE)
+ typedef tile<16, 8, int> tile_A;
+ typedef tile<16, 8, int> tile_B;
+ typedef tile<16, 16, int> tile_C;
- typedef tile<16, 8, int> tile_A;
- typedef tile< 8, 8, int> tile_B;
- typedef tile<16, 8, int> tile_C;
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_dm = (const half2 *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B;
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_A::I + tile_C::get_i(l);
+ float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
+ }
+ }
+ }
+ }
+#else
+ typedef tile<16, 8, int> tile_A;
+ typedef tile< 8, 8, int> tile_B;
+ typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
- y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
- const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
+ const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
const int * y_qs = (const int *) y + 4;
const half2 * y_dm = (const half2 *) y;
- tile_A A[ntx][WARP_SIZE/QI8_1];
- float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
+ tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
+ float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
const int k0 = k00 + k01;
load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
const int k0 = k00 + k01;
dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
tile_B B;
float2 dsB[tile_C::ne/2];
}
}
}
+#endif // defined(AMD_MFMA_AVAILABLE)
}
-template <int mmq_x, int mmq_y, int nwarps>
+// Used for Q3_K, IQ2_S, and IQ2_XS
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
const int * x_qs = (const int *) x;
const float * y_df = (const float *) y;
// #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
const int k0 = k00 + k01;
#pragma unroll
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
- &x_qs[i*(2*WARP_SIZE + 1) + k0],
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
&y_qs[j*MMQ_TILE_Y_K + k01],
- &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
+ &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
}
}
}
-template <int mmq_x, int mmq_y, int nwarps>
+// Used for Q3_K, IQ2_S, and IQ2_XS
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE)
+ typedef tile<16, 8, int> tile_A;
+ typedef tile<16, 8, int> tile_B;
+ typedef tile<16, 16, int> tile_C;
+ typedef tile<64, 2, int> tile_load;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B[1];
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B[0]);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
+ }
+ }
+ }
+ }
+#elif defined(NEW_MMA_AVAILABLE)
typedef tile<16, 4, int> tile_A;
typedef tile<16, 8, int> tile_A_8;
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
const int k0 = k00 + k01;
load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
const int k0 = k00 + k01;
dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
tile_B B[2];
float dB[tile_C::ne/2];
#else
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kqsx = threadIdx.x % QI2_K;
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
+ constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
- int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
if (need_check) {
i = min(i, i_max);
const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
#else
- x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
const int sc_m = bxi->scales[kqsx];
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
#endif // FAST_FP16_AVAILABLE
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
#else
- x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
-#endif // NEW_MMA_AVAILABLE
+ x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
const int * x_qs = (const int *) x;
}
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
constexpr int ns = 2;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
- &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
}
// Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
// As a workaround 2 separate loops are used instead.
#pragma unroll
- for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+ for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
constexpr int ns = 1;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
- &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
}
}
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE)
+ typedef tile<16, 8, int> tile_A;
+ typedef tile<16, 8, int> tile_B;
+ typedef tile<16, 16, int> tile_C;
+ typedef tile<64, 2, int> tile_load;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B[1];
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
+ const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
+ : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
+ : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
+
+ tile_C Cm;
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tile_A A1;
+ A1.x[0] = 0x01010101;
+ A1.x[1] = 0x01010101;
+ mma(Cm, A1, B[0]);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C Cd;
+ mma(Cd, A[n], B[0]);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
+ float tmp = Cd.x[l]*dm.x;
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tmp -= Cm.x[l]*dm.y;
+ }
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
+ }
+ }
+ }
+ }
+#elif defined(NEW_MMA_AVAILABLE)
typedef tile<16, 4, int> tile_A;
typedef tile<16, 8, int> tile_A_8;
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
const int k0 = k00 + k01;
load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
const int k0 = k00 + k01;
const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
}
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
tile_B B[2];
// Here load_generic is faster than load_ldmatrix.
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
tile_C Cm[2];
- if (k01 >= WARP_SIZE * 3/4) {
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
tile_A A1;
A1.x[0] = 0x01010101;
A1.x[1] = 0x01010101;
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
- if (k01 >= WARP_SIZE * 3/4) {
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
}
- sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
}
}
}
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
float2 sB[tile_C::ne/2];
#pragma unroll
#else
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
int * x_sc = (int *) (x_df + txs.dm);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kqsx = threadIdx.x % QI3_K;
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
- int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
if (need_check) {
i = min(i, i_max);
const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
#else
- x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
+ constexpr int rows_per_warp = warp_size / 4;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
- int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
if (need_check) {
i = min(i, i_max);
const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
- const int ksc = threadIdx.x % (WARP_SIZE/8);
+ const int ksc = threadIdx.x % 4;
const int ksc_low = ksc % (QI3_K/8);
const int shift_low = 4 * (ksc / (QI3_K/8));
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
const int8_t * sc8 = (const int8_t *) &sc;
const float d = bxi->d;
#pragma unroll
for (int l = 0; l < int(sizeof(int)); ++l) {
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
}
#else
- x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
-#endif // NEW_MMA_AVAILABLE
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
-#ifndef NEW_MMA_AVAILABLE
+#if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
- int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
if (need_check) {
i = min(i, i_max);
x_df[i] = bxi->d;
}
-#endif // NEW_MMA_AVAILABLE
+#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
const int * x_qs = (const int *) x;
const float * y_df = (const float *) y;
// #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
+ const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
}
((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
int * x_sc = (int *) (x_dm + txs.dm);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+
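+ // Each tile row is loaded cooperatively by threads_per_row threads; if the physical warp
+ // is wider than that (e.g. 64-wide AMD waves), one warp loads nrows rows at a time and
+ // txi is the thread's index within its row.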
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
if (need_check) {
i = min(i, i_max);
}
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
- const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
+ const int qs0 = get_int_b4(bxi->qs, txi);
-#ifdef NEW_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
#else
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
-#ifdef NEW_MMA_AVAILABLE
-
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
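+ // Scales/mins: 2 threads per row, each unpacking 4 of the row's 8 (scale, min) pairs.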
+ constexpr int rows_per_warp = warp_size / 2;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
- int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
-
- if (need_check) {
- i = min(i, i_max);
- }
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MFMA_AVAILABLE)
+ // On AMD an 'if' bound check is used instead of '%' because warp_size == 64:
+ // '%' would make threads redo rows that are already covered, costing throughput (MI300X).
+ // On NVIDIA the '%' is kept because H100 loses about 100 t/s with the 'if' condition.
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
+ if (i < mmq_y) {
+#else
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+ {
+#endif // defined(AMD_MFMA_AVAILABLE)
+ if (need_check) {
+ i = min(i, i_max);
+ }
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
- const int * scales = (const int *) bxi->scales;
- const int ksc = threadIdx.x % (WARP_SIZE/16);
+ const int * scales = (const int *) bxi->scales;
+ const int ksc = threadIdx.x % 2;
- const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
- const int m32 = unpack_scales_q45_K(scales, ksc + 2);
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
- const uint8_t * sc8 = (const uint8_t *) &sc32;
- const uint8_t * m8 = (const uint8_t *) &m32;
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
+ const uint8_t * m8 = (const uint8_t *) &m32;
- const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
-#pragma unroll
- for (int l = 0; l < int(sizeof(int)); ++l) {
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+ #pragma unroll
+ for (int l = 0; l < int(sizeof(int)); ++l) {
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+ }
}
}
-
#else
-
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
- int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
if (need_check) {
i = min(i, i_max);
x_dm[i] = bxi->dm;
}
-
+ constexpr int rows_per_warp = warp_size / 4;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
- int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);
const int * scales = (const int *) bxi->scales;
- const int ksc = threadIdx.x % (WARP_SIZE/8);
+ const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
const int scales8 = unpack_scales_q45_K(scales, ksc);
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
}
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
const int * x_qs = (const int *) x;
const half2 * y_ds = (const half2 *) y;
// #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
+ const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
- &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
+ &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
}
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
+ half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
int * x_sc = (int *) (x_dm + txs.dm);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
if (need_check) {
i = min(i, i_max);
}
const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
- const int ky = QR5_K*threadIdx.x;
+ const int ky = QR5_K*txi;
- const int ql = get_int_b4(bxi->qs, threadIdx.x);
+ const int ql = get_int_b4(bxi->qs, txi);
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
const int ql1 = (ql >> 4) & 0x0F0F0F0F;
- const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
- const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
- const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+ const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
+ const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+ const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;
- const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
- const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
+ const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
+ const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
#else
- x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
- x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
-#ifdef NEW_MMA_AVAILABLE
-
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ constexpr int rows_per_warp = warp_size / 2;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
- int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
-
- if (need_check) {
- i = min(i, i_max);
- }
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MFMA_AVAILABLE)
+ // On AMD an 'if' bound check is used instead of '%' because warp_size == 64:
+ // '%' would make threads redo rows that are already covered, costing throughput (MI300X).
+ // On NVIDIA the '%' is kept because H100 loses about 100 t/s with the 'if' condition.
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
+ if (i < mmq_y) {
+#else
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+ {
+#endif // defined(AMD_MFMA_AVAILABLE)
+ if (need_check) {
+ i = min(i, i_max);
+ }
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
- const int * scales = (const int *) bxi->scales;
- const int ksc = threadIdx.x % (WARP_SIZE/16);
+ const int * scales = (const int *) bxi->scales;
+ const int ksc = threadIdx.x % 2;
- const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
- const int m32 = unpack_scales_q45_K(scales, ksc + 2);
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
- const uint8_t * sc8 = (const uint8_t *) &sc32;
- const uint8_t * m8 = (const uint8_t *) &m32;
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
+ const uint8_t * m8 = (const uint8_t *) &m32;
- const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
#pragma unroll
- for (int l = 0; l < int(sizeof(int)); ++l) {
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+ for (int l = 0; l < int(sizeof(int)); ++l) {
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+ }
}
}
-
#else
-
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
- int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
if (need_check) {
i = min(i, i_max);
x_dm[i] = bxi->dm;
}
+ constexpr int rows_per_warp = warp_size / 4;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
- int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
const int * scales = (const int *) bxi->scales;
- const int ksc = threadIdx.x % (WARP_SIZE/8);
+ const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
const int scales8 = unpack_scales_q45_K(scales, ksc);
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
}
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
const int * x_qs = (const int *) x;
const half2 * y_ds = (const half2 *) y;
// #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
- &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
+ &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
}
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
- int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+ int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
int * x_sc = (int *) (x_df + txs.dm);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
if (need_check) {
i = min(i, i_max);
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
- const int ql = get_int_b2(bxi->ql, threadIdx.x);
+ const int ql = get_int_b2(bxi->ql, txi);
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
const int ql1 = (ql >> 4) & 0x0F0F0F0F;
- const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
- const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
- const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030;
+ const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
+ const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
+ const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030;
- const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
- const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
+ const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
+ const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
#else
- x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
- x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
- const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
- const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
-
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
- int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
-#ifdef NEW_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
#else
- x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+ x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
+ constexpr int rows_per_warp = warp_size / 4;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
- int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
-#ifdef NEW_MMA_AVAILABLE
- x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
#else
- x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
-#endif // NEW_MMA_AVAILABLE
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
const int * x_qs = (const int *) x;
const float * y_df = (const float *) y;
// #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
+ const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
- &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
- x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
+ &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+ x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
}
}
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE)
+ typedef tile<16, 8, int> tile_A;
+ typedef tile<16, 8, int> tile_B;
+ typedef tile<16, 16, int> tile_C;
+ typedef tile<64, 2, int> tile_load;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
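+ // Advance 4 ints (16 q6_K values) per step so that each accumulation below is scaled
+ // by a single q6_K sub-block scale, sc[k01/4].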
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B[1];
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B[0]);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
+ }
+ }
+ }
+ }
+#elif defined(NEW_MMA_AVAILABLE)
typedef tile<16, 4, int> tile_A;
typedef tile< 8, 4, int> tile_B;
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
- const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
const int k0 = k00 + k01;
load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
}
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
const int k0 = k00 + k01;
#pragma unroll
float tmp[ntx][tile_C::ne] = {{0.0f}};
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
tile_B B[2];
float dB[tile_C::ne/2];
#else
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kbx = threadIdx.x / QI4_NL;
- const int kqsx = threadIdx.x % QI4_NL;
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI4_NL;
+ const int kqsx = txi % QI4_NL;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
if (need_check) {
i = min(i, i_max);
const int aux_q4 = get_int_b2(bxi->qs, kqsx);
const int2 v = get_int_from_table_16(aux_q4);
- const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
-#ifdef NEW_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+ const int k0 = kbx * (2 * QI4_NL) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
#else
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
- const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
- int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
-#ifdef NEW_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
#else
- x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
-#endif // NEW_MMA_AVAILABLE
+ x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kqsx = threadIdx.x % (QI2_XXS/2);
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
if (need_check) {
i = min(i, i_max);
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
#else
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
const int ls = aux32 >> 28;
const float d = bxi->d;
-#ifdef NEW_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
#else
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
-#endif // NEW_MMA_AVAILABLE
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kqsx = threadIdx.x % (QI2_XS/2);
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
if (need_check) {
i = min(i, i_max);
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
#else
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
const int ls = bxi->scales[kqsx];
const float d = bxi->d;
-#ifdef NEW_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
#else
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
-#endif // NEW_MMA_AVAILABLE
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kqsx = threadIdx.x % (QI2_S/2);
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
if (need_check) {
i = min(i, i_max);
const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
#else
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
const int ls = bxi->scales[kqsx];
const float d = bxi->d;
-#ifdef NEW_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
#else
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
-#endif // NEW_MMA_AVAILABLE
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kqsx = threadIdx.x % (QI3_XXS/2);
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
if (need_check) {
i = min(i, i_max);
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
#else
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
const int ls = aux32 >> 28;
const float d = bxi->d;
-#ifdef NEW_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
#else
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
-#endif // NEW_MMA_AVAILABLE
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kqsx = threadIdx.x % (QI3_S/2);
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
if (need_check) {
i = min(i, i_max);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
#else
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
const float d = bxi->d;
-#ifdef NEW_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
#else
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
-#endif // NEW_MMA_AVAILABLE
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
+ half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_ds = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kqsx = threadIdx.x % QI1_S;
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
- int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
if (need_check) {
i = min(i, i_max);
const int grid0 = (grid >> 0) & 0x0F0F0F0F;
const int grid1 = (grid >> 4) & 0x0F0F0F0F;
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
#else
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
-#ifdef NEW_MMA_AVAILABLE
- x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
#else
- x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
-#endif // NEW_MMA_AVAILABLE
+ x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
- const int kbx = 0; // threadIdx.x / QI4_XS
- const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
if (need_check) {
i = min(i, i_max);
}
- const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
+ const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
const int2 v = get_int_from_table_16(aux_q4);
- const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
-#ifdef NEW_MMA_AVAILABLE
+ const int k0 = 8 * (kqsx / 4) + kqsx % 4;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
#else
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
-#endif // NEW_MMA_AVAILABLE
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
+ constexpr int rows_per_warp = warp_size / 8;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
- int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);
if (need_check) {
i = min(i, i_max);
const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
| (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
-#ifdef NEW_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
#else
- x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
-#endif // NEW_MMA_AVAILABLE
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
}
}
-template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+template<int mmq_x, int mmq_y, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(
const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
const int stride, const int i_max, const int j_max) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = j0 + threadIdx.y;
}
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (need_check && i > i_max) {
continue;
}
- dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+ dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
}
}
}
-template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
static __device__ __forceinline__ void mmq_write_back_mma(
const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
const int stride, const int i_max, const int j_max) {
- typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int nwarps = mmq_get_nwarps_device();
+
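+ // On AMD the accumulator tile is square (tileC_IJ x tileC_IJ); on NVIDIA it stays the 16x8 mma tile.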
+#if defined(AMD_MFMA_AVAILABLE)
+ constexpr int tileC_IJ = mmq_get_granularity_device(0);
+ typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
+ constexpr int rows_per_warp = granularity;
+#else
+ typedef tile<16, 8, int> tile_C;
constexpr int rows_per_warp = 2 * granularity;
+#endif
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
-#ifdef NEW_MMA_AVAILABLE
+#if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
// -------------------------------------------------------------------------------------------------------------------------------------
-template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
+template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
struct mmq_type_traits;
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
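// Example (hypothetical instantiation, for illustration): a kernel compiled for Q4_K picks its
// per-type device functions at compile time through the traits above, e.g.
//   constexpr load_tiles_mmq_t lt  = mmq_type_traits<64, 64, false, GGML_TYPE_Q4_K>::load_tiles;
//   constexpr vec_dot_mmq_t    dot = mmq_type_traits<64, 64, false, GGML_TYPE_Q4_K>::vec_dot_dp4a;
// Dropping the nwarps template parameter therefore only changes the instantiation signatures,
// not the runtime behavior of the dispatch.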
-template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+template <ggml_type type, int mmq_x, bool need_check, bool fixup>
static __device__ __forceinline__ void mul_mat_q_process_tile(
const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int stride_row_x, const int ncols_y, const int stride_col_dst,
const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ constexpr int nwarps = mmq_get_nwarps_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int mmq_y = get_mmq_y_device();
- constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
+ constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;
extern __shared__ int data_mul_mat_q[];
int * tile_y = data_mul_mat_q + mmq_x;
- int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
+ int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
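// MMQ_TILE_Y_K is assumed to equal MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1, i.e. the same number of
// ints per y tile row as the old WARP_SIZE-based expression, so only the padding granularity now
// depends on the physical warp size.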
-#ifdef NEW_MMA_AVAILABLE
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
- constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
+ constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
#else
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
- constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // NEW_MMA_AVAILABLE
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
+ constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
- float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
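// Accumulator size example (hypothetical values): with mmq_x = 64, mmq_y = 64, nwarps = 8 and
// warp_size = 32 each thread keeps 64*64/(8*32) = 16 partial sums in registers; with
// warp_size = 64 and the same nwarps it is 8.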
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
{
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
- int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;
tile_y[l] = by0[l];
}
{
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
- int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;
tile_y[l] = by0[l];
}
__syncthreads();
- vec_dot(tile_x, tile_y, sum, WARP_SIZE);
+ vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);
__syncthreads();
}
// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
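// In short: each block is assigned contiguous K-block ranges [kb0_start, kb0_stop) of the output
// tiles; all but (potentially) the last range is written directly to dst, the last one goes to
// the fixup buffer and is added onto dst afterwards by mul_mat_q_stream_k_fixup.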
-template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+template <ggml_type type, int mmq_x, bool need_check>
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
- __launch_bounds__(WARP_SIZE*nwarps, 2)
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
- __launch_bounds__(WARP_SIZE*nwarps, 1)
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
#else
- __launch_bounds__(WARP_SIZE*nwarps, 2)
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static __global__ void mul_mat_q(
return;
}
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int mmq_y = get_mmq_y_device();
// For MoE the correct indices are loaded from ids_dst.
extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
break;
}
__syncthreads();
-// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+// On AMD (except CDNA3) or old CUDA the performance with stream-k was worse, use conventional tiling instead:
-#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
{
const int wt = blockIdx.z / nchannels_y;
const int zt = blockIdx.z - wt*nchannels_y;
// __syncthreads(); // There is no previous tile that could cause a race condition.
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
break;
}
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
constexpr bool fixup = false;
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
return;
}
-#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
const int64_t blocks_per_ne00 = ncols_x / qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
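// Assuming MMQ_ITER_K == 256, this is 8 blocks per iteration for 32-element block types such as
// Q8_0 and 1 for the 256-element K-quants.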
__syncthreads();
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
break;
}
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
// The memory layout for the fixup buffer is always contiguous, therefore reset ids:
__syncthreads();
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
break;
}
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
}
-template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+template <ggml_type type, int mmq_x, bool need_check>
static __global__ void mul_mat_q_stream_k_fixup(
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
const int64_t blocks_per_ne00 = ncols_x / qk;
- float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+ sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
}
}
}
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (need_check && i > i_max) {
continue;
}
- dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+ dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
}
}
return;
const int col_high = expert_bounds[zt + 1];
const int col_diff = col_high - col_low;
- for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
+ for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
ids_dst_shared[j] = ids_dst[col_low + j];
}
__syncthreads();
}
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (need_check && i > i_max) {
continue;
}
- dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+ dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
}
}
}
};
template<ggml_type type>
-static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
+static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const size_t nbs_ids = mmq_x*sizeof(int);
- const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+ const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
- return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
+ return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}
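// Worked example (hypothetical mmq_x = 64): nbs_ids = 64*sizeof(int) = 256 bytes, nbs_y =
// 64*sizeof(block_q8_1_mmq) padded up to a multiple of nwarps*warp_size*sizeof(int), and nbs_x
// uses the mma tile layout on mma/MFMA capable devices and the dp4a tile sizes otherwise.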
template <ggml_type type, int mmq_x>
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const int nwarps = mmq_get_nwarps_host(cc);
const int mmq_y = get_mmq_y_host(cc);
- const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
+ const dim3 block_dims(warp_size, nwarps, 1);
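// Illustrative block sizes: with warp_size = 32 and e.g. nwarps = 8 this is 256 threads per
// block; on AMD wave64 devices warp_size = 64, so the same nwarps gives 512 threads
// (the actual nwarps comes from mmq_get_nwarps_host(cc)).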
- const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
+ const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);
- CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
- CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>), nbytes_shared);
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
if (!args.use_stream_k) {
if (args.nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
} else {
constexpr bool need_check = true;
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
if (args.nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
-
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
return;
}
- mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
} else {
constexpr bool need_check = true;
-
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
return;
}
- mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
}
template <ggml_type type>
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
- const int id = ggml_cuda_get_device();
- const int cc = ggml_cuda_info().devices[id].cc;
- const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const int nwarps = mmq_get_nwarps_host(cc);
const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc);
for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
const int granularity = mmq_get_granularity_host(mmq_x, cc);
- if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
+ if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
continue;
}