#include <climits>
#include <cstdint>
-#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
-#define MMQ_NWARPS 8
-
-typedef void (*load_tiles_mmq_t)(
- const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
- int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
-typedef void (*vec_dot_mmq_t)(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0);
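// The load/vec-dot function pointers now take a single tile pointer (x_tile / x);
// each per-type implementation below splits it into quantized values and scales
// itself instead of receiving separate x_qs/x_dm/x_sc arguments.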
+typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride);
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0);
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
struct block_q8_1_mmq {
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
-#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
-#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
-#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
-#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
-#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
-#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
-#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
-#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-
-#define GET_TILE_X_SIZES_BODY \
- return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
- type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 : \
- type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 : \
- type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 : \
- type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 : \
- type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K : \
- type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K : \
- type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \
- type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \
- type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \
- tile_x_sizes{0, 0, 0}
-
-static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
- GET_TILE_X_SIZES_BODY;
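// Shared memory tile sizes (in elements) for the dp4a path: quantized values (qs),
// block scales (dm) and sub-block scales (sc). The trailing "+ mmq_y..." terms are
// per-row padding (the dp4a code uses a row stride of WARP_SIZE + 1 for qs),
// presumably to avoid shared memory bank conflicts. tile_x_sizes is assumed to be
// a plain struct with int members qs, dm and sc; its definition is outside this hunk.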
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
+#define MMQ_DP4A_TXS_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+
+static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
+ return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
+ type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
+ type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
+ type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
+ type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
+ type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
+ type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
+ type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
+ type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+ tile_x_sizes{0, 0, 0};
}
-template <int mmq_y>
-static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) {
- GET_TILE_X_SIZES_BODY;
+#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4)
+#define MMQ_MMA_TILE_X_K_Q5_0 (2*WARP_SIZE + WARP_SIZE/QI5_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q5_1 (2*WARP_SIZE + WARP_SIZE/QI5_1 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_0 (1*WARP_SIZE + WARP_SIZE/QI8_0 + 0)
+#define MMQ_MMA_TILE_X_K_Q2_K (1*WARP_SIZE + WARP_SIZE + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/QI3_K + WARP_SIZE/4 + 2)
+#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
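// Per-row strides used when indexing the shared x tile in the int8 mma layout
// (quantized values, then block scales, plus padding). The static_asserts below
// keep every stride at 4 mod 8, presumably to stagger rows across shared memory
// banks and avoid conflicts.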
+
+static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+
+static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
+ return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
+ type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
+ type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
+ type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
+ type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
+ type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
+ type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
+ type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
+ type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+ 0;
+}
+
+#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+#define MMQ_NWARPS 8
+
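// Granularity in the mmq_x (column) direction. With int8 mma and mmq_x >= 48 each
// warp covers two 16-row tiles of x (ntx == 2 in the *_mma functions below) and
// warps are paired across the y columns; otherwise each warp covers a single
// 16-row tile and steps over y in blocks of 8 columns.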
+static int mmq_get_granularity_host(const int mmq_x, const int cc) {
+ return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+}
+
+#ifdef INT8_MMA_AVAILABLE
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+ return mmq_x >= 48 ? 16 : 8;
}
+#else
+static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
+ return 8;
+}
+#endif // INT8_MMA_AVAILABLE
// ------------------------------------------------------------
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
- const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
- int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_sc);
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
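// x_tile is one flat int buffer in shared memory; it is split into quantized values
// (x_qs) and scales (x_df) here, with the offset and the per-row stride depending on
// whether the int8 mma layout or the padded dp4a layout is used.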
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
const int kbx = threadIdx.x / QI4_0;
const int kqsx = threadIdx.x % QI4_0;
- float * x_dmf = (float *) x_dm;
-
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
- x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+#else
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
- x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + kbxd] = bxi->d;
+#else
+ x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- GGML_UNUSED(x_sc);
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- const float * x_df = (const float *) x_dm;
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
- GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
- const float * x_df = (const float *) x_dm;
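// With ntx == 2, each warp owns two row tiles of A but only every other 8-column
// block of y: threadIdx.y % ntx picks the y column block, threadIdx.y / ntx picks
// the block of x rows (i0).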
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + WARP_SIZE;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
- mma_A A;
- float dA[mma_C::ne/2];
+ mma_A A[ntx];
+ float dA[ntx][mma_C::ne/2];
- const int i0 = threadIdx.y*mma_A::I;
- static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + mma_A::get_i(l);
- const int k = k0 + mma_A::get_k(l) % QI4_0;
- const int shift = 4*(mma_A::get_k(l) / QI4_0);
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_A::ne; ++l) {
+ const int i = i0 + n*mma_A::I + mma_A::get_i(l);
+ const int k = k0 + mma_A::get_k(l) % QI4_0;
+ const int shift = 4*(mma_A::get_k(l) / QI4_0);
+
+ A[n].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+ }
- A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
- }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/QI4_0];
+ }
}
- for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
- mma_C C;
- mma_B B;
- half2 dsB[mma_C::ne/2];
-
#pragma unroll
- for (int l = 0; l < mma_B::ne; ++l) {
- const int j = j0 + mma_B::get_j(l);
- const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+ mma_B B;
+ float dB[mma_C::ne/2];
+
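// mma_B::load() fetches the whole J8K8 fragment from shared memory with the given
// row stride (MMQ_TILE_Y_K), replacing the earlier per-element gather loop.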
+ B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
- B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
- }
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+ dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
}
- C.mma_K8(A, B);
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
+ }
}
}
#else
- GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
- const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
- int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_sc);
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
const int kbx = threadIdx.x / QI4_1;
const int kqsx = threadIdx.x % QI4_1;
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
- x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#else
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
- x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+ x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + kbxd] = bxi->dm;
+#else
+ x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- GGML_UNUSED(x_sc);
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
- GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
+ typedef mma_int_A_I16K4 mma_A_K4;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
- mma_A A;
- half2 dmA[mma_C::ne/2];
+ mma_A A[ntx];
+ half2 dmA[ntx][mma_C::ne/2];
- const int i0 = threadIdx.y*mma_A::I;
- static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + mma_A::get_i(l);
- const int k = k0 + mma_A::get_k(l) % QI4_0;
- const int shift = 4*(mma_A::get_k(l) / QI4_0);
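// Only the packed 4-bit data is loaded, as a K4 fragment; the low and high nibbles
// are then expanded in registers into the lower and upper halves of the K8 A
// fragment, so the Q4_1 values are never stored unpacked in shared memory.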
+ for (int n = 0; n < ntx; ++n) {
+ ((mma_A_K4 *) &A[n])[0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0, MMQ_MMA_TILE_X_K_Q4_1);
+ A[n].x[2] = (A[n].x[0] >> 4) & 0x0F0F0F0F;
+ A[n].x[3] = (A[n].x[1] >> 4) & 0x0F0F0F0F;
+ A[n].x[0] &= 0x0F0F0F0F;
+ A[n].x[1] &= 0x0F0F0F0F;
- A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
- }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+ dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/QI4_1];
+ }
}
- for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
- mma_C C;
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
mma_B B;
half2 dsB[mma_C::ne/2];
-#pragma unroll
- for (int l = 0; l < mma_B::ne; ++l) {
- const int j = j0 + mma_B::get_j(l);
- const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+ B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
- B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
- }
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
}
- C.mma_K8(A, B);
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
- sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+ for (int l = 0; l < mma_C::ne; ++l) {
+ const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+ }
}
}
#else
- GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
- const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
- int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_sc);
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
const int kbx = threadIdx.x / QI5_0;
const int kqsx = threadIdx.x % QI5_0;
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
- x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
-
int qs1 = (ql >> 4) & 0x0F0F0F0F;
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
- x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
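// Changed storage order: the two unpacked halves of each Q5_0 block now sit in
// contiguous runs (kqsx and kqsx + QI5_0) rather than interleaved at 2*threadIdx.x,
// presumably so the mma A fragments and the q8_1 operands can be read contiguously.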
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+ x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // INT8_MMA_AVAILABLE
}
const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
- float * x_dmf = (float *) x_dm;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
- x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = bxi->d;
+#else
+ x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- GGML_UNUSED(x_sc);
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- const float * x_dmf = (const float *) x_dm;
- const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
- const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
- const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
-
- int u[2*VDR_Q5_0_Q8_1_MMQ];
-
-#pragma unroll
- for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
- u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
- u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE];
- }
-
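// The gather of y into a local u[] array is gone: with the new x storage order the
// q8_1 values are consumed in their natural order via a direct pointer.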
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
- (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+ (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
+ x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
- GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
- const float * x_df = (const float *) x_dm;
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + WARP_SIZE*2;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
- mma_A A;
- float dA[mma_C::ne/2];
+ mma_A A[ntx];
+ float dA[ntx][mma_C::ne/2];
- const int i0 = threadIdx.y*mma_A::I;
- static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + mma_A::get_i(l);
- const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
+ for (int n = 0; n < ntx; ++n) {
+ A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_0*k0, MMQ_MMA_TILE_X_K_Q5_0);
- A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
- }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
- dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0];
+ }
}
- for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
- mma_C C;
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
mma_B B;
float dB[mma_C::ne/2];
-#pragma unroll
- for (int l = 0; l < mma_B::ne; ++l) {
- const int j = j0 + mma_B::get_j(l);
- const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+ B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
- B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
- }
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
}
- C.mma_K8(A, B);
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
+ }
}
}
#else
- GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
- const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
- int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_sc);
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
const int kbx = threadIdx.x / QI5_1;
const int kqsx = threadIdx.x % QI5_1;
qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
- x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
-
int qs1 = (ql >> 4) & 0x0F0F0F0F;
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
- x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+ x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // INT8_MMA_AVAILABLE
}
const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
- x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+ x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + kbxd] = bxi->dm;
+#else
+ x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- GGML_UNUSED(x_sc);
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- const int * y_qs = (const int *) y + 4;
- const half2 * y_ds = (const half2 *) y;
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
- const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
- const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1;
-
- int u[2*VDR_Q5_1_Q8_1_MMQ];
-
-#pragma unroll
- for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
- u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
- u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE];
- }
-
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
- (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+ (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
+ x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
- GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
- mma_A A;
- half2 dmA[mma_C::ne/2];
+ mma_A A[ntx];
+ half2 dmA[ntx][mma_C::ne/2];
- const int i0 = threadIdx.y*mma_A::I;
- static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + mma_A::get_i(l);
- const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
+ for (int n = 0; n < ntx; ++n) {
+ A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1);
- A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
- }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
- dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
+ dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + k0/QI5_1];
+ }
}
- for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
- mma_C C;
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
mma_B B;
half2 dsB[mma_C::ne/2];
-#pragma unroll
- for (int l = 0; l < mma_B::ne; ++l) {
- const int j = j0 + mma_B::get_j(l);
- const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+ B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
- B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
- }
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
}
- C.mma_K8(A, B);
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
- sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+ for (int l = 0; l < mma_C::ne; ++l) {
+ const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+ }
}
}
#else
- GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
- const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
- int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_sc);
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_tile + WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
const int kbx = threadIdx.x / QI8_0;
const int kqsx = threadIdx.x % QI8_0;
- float * x_dmf = (float *) x_dm;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
- x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+#else
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
}
const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
- x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#else
+ x_df[i*(WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- GGML_UNUSED(x_sc);
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- const float * x_dmf = (const float *) x_dm;
- const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int i = i0 + threadIdx.x;
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
- (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
+ (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
- GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
- const float * x_df = (const float *) x_dm;
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + WARP_SIZE;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
- mma_A A;
- float dA[mma_C::ne/2];
+ mma_A A[ntx];
+ float dA[ntx][mma_C::ne/2];
- const int i0 = threadIdx.y*mma_A::I;
- static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+ const int i0 = (threadIdx.y/ntx)*rows_per_warp;
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + mma_A::get_i(l);
- const int k = k0 + mma_A::get_k(l);
+ for (int n = 0; n < ntx; ++n) {
+ A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
- A.x[l] = x_qs[i*(WARP_SIZE + 1) + k];
- }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
- dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+ }
}
- for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
- mma_C C;
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
mma_B B;
float dB[mma_C::ne/2];
-#pragma unroll
- for (int l = 0; l < mma_B::ne; ++l) {
- const int j = j0 + mma_B::get_j(l);
- const int k = k0 + mma_B::get_k(l);
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
- B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
- }
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
}
- C.mma_K8(A, B);
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+ }
}
}
#else
- GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
- const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
- int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
const int kbx = threadIdx.x / QI2_K;
const int kqsx = threadIdx.x % QI2_K;
continue;
}
- x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
+#else
+ x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
}
const int sc_m = bxi->scales[kqsx];
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
#endif // FAST_FP16_AVAILABLE
- x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
+#ifdef INT8_MMA_AVAILABLE
+ x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + threadIdx.x] = x_dm_ik;
+#else
+ x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
+#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K4 mma_A;
typedef mma_int_B_J8K4 mma_B;
typedef mma_int_C_I16J8 mma_C;
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
- const int i0 = threadIdx.y*mma_A::I;
- static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
- mma_A A[2];
- float dA[mma_C::ne/2][2];
- float mA[mma_C::ne/2][2];
+ mma_A A[ntx][2];
+ float dA[ntx][mma_C::ne/2][2];
+ float mA[ntx][mma_C::ne/2][2];
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + mma_A::get_i(l);
- const int shift = 2*mma_A::get_k(l);
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_A::ne; ++l) {
+ const int i = i0 + n*mma_A::I + mma_A::get_i(l);
+ const int shift = 2*mma_A::get_k(l);
- A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303;
- A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303;
- }
+ A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 0] >> shift) & 0x03030303;
+ A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303;
+ }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
#pragma unroll
- for (int kk = 0; kk < 2; ++kk) {
- const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]);
+ for (int kdm = 0; kdm < 2; ++kdm) {
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + kdm]);
- dA[l][kk] = dm.x;
- mA[l][kk] = dm.y;
+ dA[n][l][kdm] = dm.x;
+ mA[n][l][kdm] = dm.y;
+ }
}
}
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
- mma_C Cd[2];
- mma_C Cm[2];
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
mma_B B[2];
float dB[mma_C::ne/2];
-#pragma unroll
- for (int l = 0; l < mma_B::ne; ++l) {
- const int j = j0 + mma_B::get_j(l);
- const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K);
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
- B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
- B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
- }
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
}
- Cd[0].mma_K4(A[0], B[0]);
- Cd[1].mma_K4(A[1], B[1]);
-
+ mma_C Cm[2];
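// An A fragment of all-ones bytes (0x01): the mma with B then yields plain sums of
// the q8 values, which provide the -m (minimum) term of the Q2_K scales in the
// accumulation below.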
mma_A A1;
A1.x[0] = 0x01010101;
A1.x[1] = 0x01010101;
Cm[1].mma_K4(A1, B[1]);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2];
+ for (int n = 0; n < ntx; ++n) {
+ mma_C Cd[2];
+
+ Cd[0].mma_K4(A[n][0], B[0]);
+ Cd[1].mma_K4(A[n][1], B[1]);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += (
+ Cd[0].x[l]*dA[n][l/2][0] + Cd[1].x[l]*dA[n][l/2][1] - Cm[0].x[l]*mA[n][l/2][0] - Cm[1].x[l]*mA[n][l/2][1])*dB[l%2];
+ }
}
}
#else
- GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
- const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
- int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
const int kbx = threadIdx.x / QI3_K;
const int kqsx = threadIdx.x % QI3_K;
continue;
}
- x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2] = x_qs_k;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
}
}
const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
- float * x_dmf = (float *) x_dm;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
- x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + kbxd] = bxi->d;
+#else
+ x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
}
#pragma unroll
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
- x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc;
+#ifdef INT8_MMA_AVAILABLE
+ x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/4)] = sc;
+#else
+ x_sc[i*(WARP_SIZE/4) + i/4 + threadIdx.x % (WARP_SIZE/4)] = sc;
+#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- const float * x_df = (const float *) x_dm;
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_df + txs.dm;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K4 mma_A;
typedef mma_int_B_J8K4 mma_B;
typedef mma_int_C_I16J8 mma_C;
- const float * x_df = (const float *) x_dm;
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+ const int * x_sc = (const int *) x_df + WARP_SIZE/QI3_K;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
- const int i0 = threadIdx.y*mma_A::I;
- static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
- mma_A A[2];
- int scA[mma_C::ne/2][2];
- float dA[mma_C::ne/2];
+ mma_A A[ntx][2];
+ int scA[ntx][mma_C::ne/2][2];
+ float dA[ntx][mma_C::ne/2];
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + mma_A::get_i(l);
- const int k = QR3_K*k0 + mma_A::get_k(l);
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_A::ne; ++l) {
+ const int i = i0 + n*mma_A::I + mma_A::get_i(l);
+ const int k = QR3_K*k0 + mma_A::get_k(l);
- A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F;
- A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
- A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404);
- A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404);
- }
+ A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F;
+ A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
+ A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404);
+ A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404);
+ }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- const int kbx = k0 / QI3_K;
- const int ky = (k0 % QI3_K) * QR3_K;
- const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+ const int kbx = k0 / QI3_K;
+ const int ky = (k0 % QI3_K) * QR3_K;
+ const int8_t * sc = ((const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q3_K + kbx*4)) + ky/4;
- scA[l][0] = sc[0];
- scA[l][1] = sc[1];
- }
+ scA[n][l][0] = sc[0];
+ scA[n][l][1] = sc[1];
+ }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K];
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/QI3_K];
+ }
}
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
- mma_C C[2];
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
mma_B B[2];
float dB[mma_C::ne/2];
-#pragma unroll
- for (int l = 0; l < mma_B::ne; ++l) {
- const int j = j0 + mma_B::get_j(l);
- const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K);
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
- B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
- B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
- }
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
}
- C[0].mma_K4(A[0], B[0]);
- C[1].mma_K4(A[1], B[1]);
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C[2];
+ C[0].mma_K4(A[n][0], B[0]);
+ C[1].mma_K4(A[n][1], B[1]);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2];
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += (C[0].x[l]*scA[n][l/2][0] + C[1].x[l]*scA[n][l/2][1])*dA[n][l/2]*dB[l%2];
+ }
}
}
#else
- GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
- const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
- int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+ int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
const int kbx = 0; // threadIdx.x / QI4_K
const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
- x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#else
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;
- x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+ x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + kbxd] = bxi->dm;
+#else
+ x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
}
#pragma unroll
int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
- x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+#ifdef INT8_MMA_AVAILABLE
+ x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + ksc] = scales8;
+#else
+ x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_dm + txs.dm;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
+ const int * x_sc = (const int *) x_dm + WARP_SIZE/QI4_K;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
- const int i0 = threadIdx.y*mma_A::I;
- static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+ mma_A A[ntx][2];
+ int scA[ntx][mma_C::ne/2][2];
+ int mA[ntx][mma_C::ne/2][2];
+ half2 dmA[ntx][mma_C::ne/2];
- mma_A A[2];
- int scA[mma_C::ne/2][2];
- int mA[mma_C::ne/2][2];
- half2 dmA[mma_C::ne/2];
#pragma unroll
- for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
+ for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + mma_A::get_i(l);
- const int k = k0 + mma_A::get_k(l);
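// Same trick as in vec_dot_q4_1_q8_1_mma: load the packed 4-bit data once, then
// split low and high nibbles into two A fragments in registers instead of reading
// shared memory twice.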
+ for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 8) {
+ A[n][kvdr/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0, MMQ_MMA_TILE_X_K_Q4_K);
- A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
+#pragma unroll
+ for (int l = 0; l < mma_A::ne; ++l) {
+ A[n][kvdr/4 + 1].x[l] = (A[n][kvdr/4 + 0].x[l] >> 4) & 0x0F0F0F0F;
+ A[n][kvdr/4 + 0].x[l] &= 0x0F0F0F0F;
+ }
}
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
- const uint8_t * m = sc + 8;
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + k0/16]) + 2 * ((k0 % 16) / 8);
+ const uint8_t * m = sc + 8;
- scA[l][kvdr/4] = sc[kvdr/4];
- mA[l][kvdr/4] = m[kvdr/4];
+ scA[n][l][kvdr/4] = sc[kvdr/4];
+ mA[n][l][kvdr/4] = m[kvdr/4];
+ }
}
- }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
- dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+ dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + k0/QI4_K];
+ }
}
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
- float tmpd[mma_C::ne] = {0.0f};
- float tmpm[mma_C::ne] = {0.0f};
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+ float tmpd[ntx][mma_C::ne] = {{0.0f}};
+ float tmpm[ntx][mma_C::ne] = {{0.0f}};
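+ // tmpd accumulates the sub-block-scaled dot products and tmpm the min
+ // corrections for each of the ntx minitiles; both are folded into the sum
+ // with the per-row d/min pair dmA after the kvdr loop.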
#pragma unroll
- for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
- mma_C C;
+ for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
mma_B B;
half2 dsB[mma_C::ne/2];
-#pragma unroll
- for (int l = 0; l < mma_B::ne; ++l) {
- const int j = j0 + mma_B::get_j(l);
- const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+ B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
- B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
- }
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
}
- C.mma_K8(A[kvdr/4], B);
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n][kvdr/4], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
- tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
+ for (int l = 0; l < mma_C::ne; ++l) {
+ tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]);
+ tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]);
+ }
}
}
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
+ }
}
}
#else
- GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
- const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
- int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
+ int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
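+ // With int8 MMA the qs, dm and sc values of each row are interleaved with a
+ // row stride of MMQ_MMA_TILE_X_K_Q5_K ints; the dp4a path keeps the three
+ // arrays separate, sized by mmq_get_dp4a_tile_x_sizes.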
const int kbx = 0; // threadIdx.x / QI5_K
const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
- x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
- x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq0] = ql0 | qh0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq1] = ql1 | qh1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+ x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+#endif // INT8_MMA_AVAILABLE
}
const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;
- x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+ x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + kbxd] = bxi->dm;
+#else
+ x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
}
#pragma unroll
int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
- x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+#ifdef INT8_MMA_AVAILABLE
+ x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + ksc] = scales8;
+#else
+ x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- const int * y_qs = (const int *) y + 4;
- const half2 * y_ds = (const half2 *) y;
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_dm + txs.dm;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
+ const int * x_sc = (const int *) x_dm + WARP_SIZE/QI5_K;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
- const int i0 = threadIdx.y*mma_A::I;
- static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+ mma_A A[ntx][2];
+ int scA[ntx][mma_C::ne/2][2];
+ int mA[ntx][mma_C::ne/2][2];
+ half2 dmA[ntx][mma_C::ne/2];
- mma_A A[2];
- int scA[mma_C::ne/2][2];
- int mA[mma_C::ne/2][2];
- half2 dmA[mma_C::ne/2];
#pragma unroll
- for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+ for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + mma_A::get_i(l);
- const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
-
- A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k];
- }
+ for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+ A[n][kvdr/4].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + (QR5_K*k0 + QR5_K*kvdr), MMQ_MMA_TILE_X_K_Q5_K);
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
- const uint8_t * m = sc + 8;
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + k0/16]) + 2 * ((k0 % 16) / 8);
+ const uint8_t * m = sc + 8;
- scA[l][kvdr/4] = sc[kvdr/4];
- mA[l][kvdr/4] = m[kvdr/4];
+ scA[n][l][kvdr/4] = sc[kvdr/4];
+ mA[n][l][kvdr/4] = m[kvdr/4];
+ }
}
- }
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+ dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + k0/QI5_K];
+ }
}
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
- float tmpd[mma_C::ne] = {0.0f};
- float tmpm[mma_C::ne] = {0.0f};
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+ float tmpd[ntx][mma_C::ne] = {{0.0f}};
+ float tmpm[ntx][mma_C::ne] = {{0.0f}};
#pragma unroll
for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
- mma_C C;
mma_B B;
half2 dsB[mma_C::ne/2];
-#pragma unroll
- for (int l = 0; l < mma_B::ne; ++l) {
- const int j = j0 + mma_B::get_j(l);
- const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+ B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
- B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
- }
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
}
- C.mma_K8(A[kvdr/4], B);
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n][kvdr/4], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
- tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
+ for (int l = 0; l < mma_C::ne; ++l) {
+ tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]);
+ tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]);
+ }
}
}
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
+ }
}
}
#else
- GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
- const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
- int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
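+ // Q6_K blocks have a single scale and no min, so the tile stores float
+ // scales (x_df) instead of the half2 d/min pairs used for Q4_K/Q5_K.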
const int kbx = 0; // threadIdx.x / QI6_K
const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
- x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
- x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+ x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // INT8_MMA_AVAILABLE
}
const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
- float * x_dmf = (float *) x_dm;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
- x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
+#else
+ x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
}
#pragma unroll
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
- x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+#ifdef INT8_MMA_AVAILABLE
+ x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+#else
+ x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- const float * x_dmf = (const float *) x_dm;
- const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_df + txs.dm;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
&x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
- x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
+ x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
- const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
- const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K4 mma_A;
typedef mma_int_B_J8K4 mma_B;
typedef mma_int_C_I16J8 mma_C;
- const float * x_df = (const float *) x_dm;
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+ const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
- const int i0 = threadIdx.y*mma_A::I;
-#ifdef INT8_MMA_AVAILABLE
- static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
-#endif // INT8_MMA_AVAILABLE
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+ mma_A A[ntx][4];
+ int scA[ntx][mma_C::ne/2][4];
+ float dA[ntx][mma_C::ne/2];
- mma_A A[4];
- int scA[mma_C::ne/2][4];
- float dA[mma_C::ne/2];
#pragma unroll
- for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+ for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + mma_A::get_i(l);
- const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
-
- A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0];
- A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
- }
+ for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+ A[n][kvdr/2 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + 0), MMQ_MMA_TILE_X_K_Q6_K);
+ A[n][kvdr/2 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
+ const int8_t * sc = ((const int8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/8]);
- scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0];
- scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1];
+ scA[n][l][kvdr/2 + 0] = sc[kvdr/2 + 0];
+ scA[n][l][kvdr/2 + 1] = sc[kvdr/2 + 1];
+ }
}
- }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K];
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K + k0/QI6_K];
+ }
}
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
- float tmp[mma_C::ne] = {0.0f};
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+ float tmp[ntx][mma_C::ne] = {{0.0f}};
#pragma unroll
for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
- mma_C C[2];
mma_B B[2];
float dB[mma_C::ne/2];
-#pragma unroll
- for (int l = 0; l < mma_B::ne; ++l) {
- const int j = j0 + mma_B::get_j(l);
- const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+ const int k0B = (2*k0 + 2*kvdr) % WARP_SIZE;
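+ // Two J8K4 fragments are loaded at k0B and k0B + mma_B::K so that together
+ // they cover the 8 ints per row consumed by the paired mma_K4 calls below.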
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k0B, MMQ_TILE_Y_K);
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K);
- B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
- B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
- }
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
}
- C[0].mma_K4(A[kvdr/2 + 0], B[0]);
- C[1].mma_K4(A[kvdr/2 + 1], B[1]);
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C[2];
+ C[0].mma_K4(A[n][kvdr/2 + 0], B[0]);
+ C[1].mma_K4(A[n][kvdr/2 + 1], B[1]);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2];
+ for (int l = 0; l < mma_C::ne; ++l) {
+ tmp[n][l] += (C[0].x[l]*scA[n][l/2][kvdr/2 + 0] + C[1].x[l]*scA[n][l/2][kvdr/2 + 1])*dB[l%2];
+ }
}
}
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+ }
}
}
#else
- GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
typedef mma_int_C_I16J8 mma_C;
- const int i0 = threadIdx.y*mma_C::I;
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
#ifdef INT8_MMA_AVAILABLE
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
#endif // INT8_MMA_AVAILABLE
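+ // Warps that share the same output rows (same threadIdx.y / ntx) write
+ // disjoint blocks of mma_C::J output columns, selected by threadIdx.y % ntx.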
+ dst += (threadIdx.y % ntx) * mma_C::J*stride;
+
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- const int j = j0 + mma_C::get_j(l);
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ const int j = j0 + mma_C::get_j(l);
- if (j > j_max) {
- continue;
- }
+ if (j > j_max) {
+ continue;
+ }
- const int i = i0 + mma_C::get_i(l);
+ const int i = i0 + n*mma_C::I + mma_C::get_i(l);
- if (need_check && i > i_max) {
- continue;
- }
+ if (need_check && i > i_max) {
+ continue;
+ }
- dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+ dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+ }
}
}
}
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
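+ // tile_y occupies the start of the dynamic shared memory; tile_x follows at
+ // an offset padded to a multiple of nwarps*WARP_SIZE ints. Whether tile_x
+ // holds the MMA row-interleaved layout or the separate dp4a arrays is decided
+ // inside the per-type load_tiles/vec_dot functions.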
+ extern __shared__ char data_mul_mat_q[];
+ int * tile_y = (int *) data_mul_mat_q;
+ int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
+
#ifdef INT8_MMA_AVAILABLE
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
- constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
-
- extern __shared__ char data_mul_mat_q[];
- int * tile_x_qs = (int *) data_mul_mat_q;
- half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs);
- int * tile_x_sc = (int *) (tile_x_dm + txs.dm);
- int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
-
constexpr int blocks_per_warp = WARP_SIZE / qi;
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
- load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
+ load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
#pragma unroll
for (int kr = 0; kr < qr; ++kr) {
// #pragma unroll // unrolling this loop causes too much register pressure
for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
- vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0);
+ vec_dot(tile_x, tile_y, sum, k0);
}
__syncthreads();
const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
// Skip unused template specializations for faster compilation:
- if (mmq_x > get_mmq_x_max_device()) {
+ if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
NO_DEVICE_CODE;
return;
}
int64_t ne0;
};
-static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
- const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
-
- const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
- const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
+template<ggml_type type>
+static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
+ const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
+ const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
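+ // With int8 MMA available for this compute capability the x tile needs
+ // mmq_y rows of mmq_tile_x_k ints; otherwise it is the sum of the dp4a
+ // arrays. The y tile now scales as mmq_x*sizeof(block_q8_1_mmq).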
+ const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+ const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
}
const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
- const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
+ const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
int nparts_best = INT_MAX;
for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
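+ // Candidate tile widths that are not a multiple of the MMA granularity or
+ // whose tiles would not fit into the shared memory available per block are
+ // skipped up front.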
+ const int granularity = mmq_get_granularity_host(mmq_x, cc);
+
+ if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
+ continue;
+ }
+
const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
const int nwaves_xy_tiling = ntiles_x*block_num_y;
-
const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
- if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
+ if (nparts < nparts_best) {
mmq_x_best = mmq_x;
nparts_best = nparts;
}