};
static int get_mmq_x_max_host(const int cc) {
- return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
+ return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
#ifdef GGML_CUDA_FORCE_MMQ
128 : 64;
}
static constexpr __device__ int get_mmq_x_max_device() {
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
return 128;
-#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP)
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
static int get_mmq_y_host(const int cc) {
#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
- if (amd_mfma_available(cc)) {
+ if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
return mmq_x >= 128 ? 32 : 16;
} else if (turing_mma_available(cc) && mmq_x >= 48) {
return 16;
}
}
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
return mmq_x >= 128 ? 32 : 16;
}
#endif // (GGML_USE_HIP)
static constexpr __device__ int mmq_get_nwarps_device() {
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
return 8;
#else
return 256/ggml_cuda_get_physical_warp_size();
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
constexpr int nrows = warp_size / threads_per_row;
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
const int qs0 = get_int_b2(bxi->qs, kqsx);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
#else
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
constexpr int nrows = warp_size / threads_per_row;
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
const int qs0 = get_int_b4(bxi->qs, kqsx);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
#else
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
#else
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
constexpr int nrows = warp_size / threads_per_row;
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
constexpr int nrows = warp_size / threads_per_row;
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
#else
x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
// MMQ_ITER_K / (4 * QR8_0) == 64 threads per row would be required, but NVIDIA warps have only 32 threads
constexpr int threads_per_row = 32;
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
constexpr int nrows = warp_size / threads_per_row;
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
#else
x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
typedef tile<16, 8, int> tile_A;
typedef tile<16, 8, int> tile_B;
typedef tile<16, 16, int> tile_C;
}
}
}
-#endif // defined(AMD_MFMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
template <int mmq_x, int mmq_y>
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
typedef tile<16, 8, int> tile_A;
typedef tile<16, 8, int> tile_B;
typedef tile<16, 16, int> tile_C;
}
}
}
-#endif // defined(AMD_MFMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
// Used for Q3_K, IQ2_S, and IQ2_XS
tile_C C;
mma(C, A[n], B[0]);
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
+ }
+ }
+ }
+ }
+#elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles, so loading 64x2 tiles is not required
+ typedef tile<16, 4, int> tile_A;
+ typedef tile<16, 4, int> tile_B;
+ typedef tile<16, 16, int> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
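+ // Each y row spans MMQ_TILE_Y_K ints: the q8_1 block scales come first (hence the +4 offset), then the quantized values.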
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
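+ // Step through the MMQ_TILE_NE_K-wide K slice in chunks of 4 ints, the K extent of one 16x4 WMMA fragment.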
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B;
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
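+ // One y scale covers QI8_1 ints, so k01/QI8_1 selects the scale for this K step.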
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B);
+
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
#else
GGML_UNUSED_VARS(x, y, sum, k00);
NO_DEVICE_CODE;
-#endif // AMD_MFMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
}
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
const int sc_m = bxi->scales[kqsx];
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
#endif // FAST_FP16_AVAILABLE
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
#else
x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
tile_C Cd;
mma(Cd, A[n], B[0]);
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
+ float tmp = Cd.x[l]*dm.x;
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tmp -= Cm.x[l]*dm.y;
+ }
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
+ }
+ }
+ }
+ }
+#elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles, so loading 64x2 tiles is not required
+
+ typedef tile<16, 4, int> tile_A;
+ typedef tile<16, 4, int> tile_B;
+ typedef tile<16, 16, int> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
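+ // y metadata is packed as half2: the first pair holds the two d scales, the following pairs hold partial sums of the y values used for the Q2_K minimum correction.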
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B;
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
+ const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
+ : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
+ : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
+
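+ // No stored partial sum covers the last quarter of the K slice (sB == 0 there); an all-ones A fragment makes the mma produce the needed column sums of B instead.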
+ tile_C Cm;
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tile_A A1;
+ A1.x[0] = 0x01010101;
+ A1.x[1] = 0x01010101;
+ mma(Cm, A1, B);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C Cd;
+ mma(Cd, A[n], B);
+
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
#else
GGML_UNUSED_VARS(x, y, sum, k00);
NO_DEVICE_CODE;
-#endif // AMD_MFMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
}
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
const int8_t * sc8 = (const int8_t *) ≻
const float d = bxi->d;
}
#else
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
-#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
x_df[i] = bxi->d;
}
-#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
}
template <int mmq_x, int mmq_y>
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
int * x_sc = (int *) (x_dm + txs.dm);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
constexpr int nrows = warp_size / threads_per_row;
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
const int qs0 = get_int_b4(bxi->qs, txi);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
#else
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int rows_per_warp = warp_size / 2;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
// Need an 'if' on AMD instead of '%' because warp_size == 64.
// This causes double work and a throughput loss (MI300X).
// H100 loses about 100 t/s with the 'if' condition instead of '%'.
#else
int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
{
-#endif // defined(AMD_MFMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
if (need_check) {
i = min(i, i_max);
}
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
}
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
template <int mmq_x, int mmq_y>
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
#else
const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int rows_per_warp = warp_size / 2;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
#else
int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
{
-#endif // defined(AMD_MFMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
if (need_check) {
i = min(i, i_max);
}
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
}
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
template <int mmq_x, int mmq_y>
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
int * x_sc = (int *) (x_df + txs.dm);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
constexpr int nrows = warp_size / threads_per_row;
const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
#pragma unroll
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
#else
x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
constexpr int rows_per_warp = warp_size / 4;
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
#else
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
tile_C C;
mma(C, A[n], B[0]);
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
+ }
+ }
+ }
+ }
+#elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles, so loading 64x2 tiles is not required
+ typedef tile<16, 4, int> tile_A;
+ typedef tile<16, 4, int> tile_B;
+ typedef tile<16, 16, int> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
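+ // x_df holds one d per row; the packed signed 8-bit Q6_K sub-block scales (one per 16 values) follow it.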
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B;
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B);
+
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
#else
GGML_UNUSED_VARS(x, y, sum, k00);
NO_DEVICE_CODE;
-#endif // AMD_MFMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
}
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
constexpr int nrows = warp_size / threads_per_row;
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int k0 = kbx * (2 * QI4_NL) + kqsx;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
#else
x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
constexpr int nrows = warp_size / threads_per_row;
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
const int ls = aux32 >> 28;
const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
#else
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
constexpr int nrows = warp_size / threads_per_row;
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
const int ls = bxi->scales[kqsx];
const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
#else
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
-
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
constexpr int nrows = warp_size / threads_per_row;
const int kqsx = threadIdx.x % threads_per_row;
const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
const int ls = bxi->scales[kqsx];
const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
#else
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
constexpr int nrows = warp_size / threads_per_row;
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
const int ls = aux32 >> 28;
const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
#else
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
constexpr int nrows = warp_size / threads_per_row;
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
#else
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_ds = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
constexpr int nrows = warp_size / threads_per_row;
const int grid0 = (grid >> 0) & 0x0F0F0F0F;
const int grid1 = (grid >> 4) & 0x0F0F0F0F;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
#else
x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
constexpr int nrows = warp_size / threads_per_row;
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int k0 = 8 * (kqsx / 4) + kqsx % 4;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
constexpr int rows_per_warp = warp_size / 8;
const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
| (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
#else
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int nwarps = mmq_get_nwarps_device();
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int tileC_IJ = mmq_get_granularity_device(0);
typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
constexpr int rows_per_warp = granularity;
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
-#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
#else
GGML_UNUSED(nwarps);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
int * tile_y = data_mul_mat_q + mmq_x;
int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
#else
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int blocks_per_iter = MMQ_ITER_K / qk;