#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
typedef void (*load_tiles_mmq_t)(
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
typedef void (*vec_dot_mmq_t)(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0);
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
struct tile_x_sizes {
- int ql;
+ int qs;
int dm;
- int qh;
int sc;
};
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0, 0}
-#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0, 0}
-#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0, 0}
-#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0, 0}
-#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0, 0}
-#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0, mmq_y*WARP_SIZE/4 + mmq_y/4}
-#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4}
-#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
+#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
+#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
+#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
+#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
+#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
+#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
+#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define GET_TILE_X_SIZES_BODY \
return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \
type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \
type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \
- tile_x_sizes{0, 0, 0, 0}
+ tile_x_sizes{0, 0, 0}
static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
GET_TILE_X_SIZES_BODY;
// ------------------------------------------------------------
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+ GGML_UNUSED(x_sc);
const int kbx = threadIdx.x / QI4_0;
const int kqsx = threadIdx.x % QI4_0;
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+ x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+ GGML_UNUSED(x_sc);
const float * x_df = (const float *) x_dm;
const int * y_qs = (const int *) y + 4;
}
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
- (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
+ (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
+ GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
const int k = k0 + mma_A::get_k(l) % QI4_0;
const int shift = 4*(mma_A::get_k(l) / QI4_0);
- A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+ A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
}
}
+#else
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+ GGML_UNUSED(x_sc);
const int kbx = threadIdx.x / QI4_1;
const int kqsx = threadIdx.x % QI4_1;
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+ GGML_UNUSED(x_sc);
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
}
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
- (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
+ (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
+ GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
const int k = k0 + mma_A::get_k(l) % QI4_0;
const int shift = 4*(mma_A::get_k(l) / QI4_0);
- A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
+ A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
}
}
+#else
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+ GGML_UNUSED(x_sc);
const int kbx = threadIdx.x / QI5_0;
const int kqsx = threadIdx.x % QI5_0;
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
- x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
+ x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
int qs1 = (ql >> 4) & 0x0F0F0F0F;
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
- x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+ x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
}
const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+ GGML_UNUSED(x_sc);
const float * x_dmf = (const float *) x_dm;
const int * y_qs = (const int *) y + 4;
}
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
- (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+ (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
+ GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
const int i = i0 + mma_A::get_i(l);
const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
- A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+ A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
}
}
+#else
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+ GGML_UNUSED(x_sc);
const int kbx = threadIdx.x / QI5_1;
const int kqsx = threadIdx.x % QI5_1;
qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
- x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
+ x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
int qs1 = (ql >> 4) & 0x0F0F0F0F;
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
- x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+ x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
}
const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+ GGML_UNUSED(x_sc);
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
}
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
- (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+ (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
+ GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
const int i = i0 + mma_A::get_i(l);
const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
- A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+ A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
}
}
+#else
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+ GGML_UNUSED(x_sc);
const int kbx = threadIdx.x / QI8_0;
const int kqsx = threadIdx.x % QI8_0;
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+ x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+ GGML_UNUSED(x_sc);
const float * x_dmf = (const float *) x_dm;
const int * y_qs = (const int *) y + 4;
const int i = i0 + threadIdx.x;
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
- (&x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
+ (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
+ GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
const int i = i0 + mma_A::get_i(l);
const int k = k0 + mma_A::get_k(l);
- A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
+ A.x[l] = x_qs[i*(WARP_SIZE + 1) + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
}
}
+#else
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_qh);
const int kbx = threadIdx.x / QI2_K;
const int kqsx = threadIdx.x % QI2_K;
const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
- }
-
- const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
- const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+ const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
- int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+ for (int l = 0; l < QR2_K; ++l) {
+ const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;
- if (need_check) {
- i = min(i, i_max);
- }
-
- const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd;
-
- x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
- }
+ int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
+ x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
+ x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);
-#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
- int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+ if (kqsx % QR2_K != 0) {
+ continue;
+ }
- if (need_check) {
- i = min(i, i_max);
+ x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
}
- const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4);
+ const int sc_m = bxi->scales[kqsx];
+#ifdef FAST_FP16_AVAILABLE
+ const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
+#else
+ const float2 bxi_dmf = __half22float2(bxi->dm);
+ const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
+#endif // FAST_FP16_AVAILABLE
- x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4));
+ x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
}
}
template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- GGML_UNUSED(x_qh);
-
- const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
- const int kbx = k0 / QI2_K;
- const int ky = (k0 % QI2_K) * QR2_K;
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
+ &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
+ &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+#ifdef INT8_MMA_AVAILABLE
+
+ typedef mma_int_A_I16K4 mma_A;
+ typedef mma_int_B_J8K4 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
- int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
+ const int i0 = threadIdx.y*mma_A::I;
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
- const int kqsx = i*(WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
- const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
+ mma_A A[2];
+ float dA[mma_C::ne/2][2];
+ float mA[mma_C::ne/2][2];
#pragma unroll
- for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
- v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
- }
+ for (int l = 0; l < mma_A::ne; ++l) {
+ const int i = i0 + mma_A::get_i(l);
+ const int shift = 2*mma_A::get_k(l);
- const uint8_t * scales = ((const uint8_t *) &x_sc[i*(WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
+ A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303;
+ A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303;
+ }
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
- v, &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], scales,
- x_dm[i*(WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + mma_C::get_i(2*l);
+
+#pragma unroll
+ for (int kk = 0; kk < 2; ++kk) {
+ const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]);
+
+ dA[l][kk] = dm.x;
+ mA[l][kk] = dm.y;
}
}
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+ mma_C Cd[2];
+ mma_C Cm[2];
+ mma_B B[2];
+ float dB[mma_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < mma_B::ne; ++l) {
+ const int j = j0 + mma_B::get_j(l);
+ const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+ B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
+ B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
+ }
+
+ Cd[0].mma_K4(A[0], B[0]);
+ Cd[1].mma_K4(A[1], B[1]);
+
+ mma_A A1;
+ A1.x[0] = 0x01010101;
+ A1.x[1] = 0x01010101;
+ Cm[0].mma_K4(A1, B[0]);
+ Cm[1].mma_K4(A1, B[1]);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2];
+ }
+ }
+#else
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
const int kbx = threadIdx.x / QI3_K;
const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+ const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
+ const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
+
+#pragma unroll
+ for (int l = 0; l < QR3_K; ++l) {
+ const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8;
+
+ const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
+ const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
+
+ int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
+ x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
+
+ if (kqsx % 2 != 0) {
+ continue;
+ }
+
+ x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
+ }
}
const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
}
-#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
- int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2);
-
- if (need_check) {
- i = min(i, i_max);
- }
-
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2);
-
- // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
- x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2));
- }
-
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
}
template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- const float * x_dmf = (const float *) x_dm;
- const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
+ const float * x_df = (const float *) x_dm;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
- int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
+ &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
+ x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+#ifdef INT8_MMA_AVAILABLE
+
+ typedef mma_int_A_I16K4 mma_A;
+ typedef mma_int_B_J8K4 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ const float * x_df = (const float *) x_dm;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = threadIdx.y*mma_A::I;
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+ mma_A A[2];
+ int scA[mma_C::ne/2][2];
+ float dA[mma_C::ne/2];
#pragma unroll
- for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
- const int kqsx = i*(WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
- const int shift = 2 * ((ky % 32) / 8);
- const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
+ for (int l = 0; l < mma_A::ne; ++l) {
+ const int i = i0 + mma_A::get_i(l);
+ const int k = QR3_K*k0 + mma_A::get_k(l);
- const int vh = x_qh[i*(WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
- const int vlh = (vh << 2) & 0x04040404;
+ A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F;
+ A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
+ A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404);
+ A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404);
+ }
- v[l] = __vsubss4(vll, vlh);
- }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + mma_C::get_i(2*l);
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
- v, &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
- x_dmf[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
+ const int kbx = k0 / QI3_K;
+ const int ky = (k0 % QI3_K) * QR3_K;
+ const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+
+ scA[l][0] = sc[0];
+ scA[l][1] = sc[1];
+ }
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + mma_C::get_i(2*l);
+
+ dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K];
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+ mma_C C[2];
+ mma_B B[2];
+ float dB[mma_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < mma_B::ne; ++l) {
+ const int j = j0 + mma_B::get_j(l);
+ const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+ B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
+ B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
+ }
+
+ C[0].mma_K4(A[0], B[0]);
+ C[1].mma_K4(A[1], B[1]);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2];
}
}
+#else
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_qh);
const int kbx = 0; // threadIdx.x / QI4_K
const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- GGML_UNUSED(x_qh);
-
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
- &x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
+ &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
const int i = i0 + mma_A::get_i(l);
const int k = k0 + mma_A::get_k(l);
- A[kvdr/4].x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
+ A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
}
#pragma unroll
sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
}
}
+#else
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_qh);
const int kbx = 0; // threadIdx.x / QI5_K
const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
- x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
- x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+ x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+ x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
}
const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- GGML_UNUSED(x_qh);
-
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
- &x_ql[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
+ &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
const int i = i0 + mma_A::get_i(l);
const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
- A[kvdr/4].x[l] = x_ql[i*(QR5_K*WARP_SIZE + 1) + k];
+ A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k];
}
#pragma unroll
sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
}
}
+#else
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
- GGML_UNUSED(x_qh);
const int kbx = 0; // threadIdx.x / QI6_K
const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
- x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
- x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+ x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+ x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
- GGML_UNUSED(x_qh);
-
const float * x_dmf = (const float *) x_dm;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
- &x_ql[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
+ &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K4 mma_A;
typedef mma_int_B_J8K4 mma_B;
const float * y_df = (const float *) y;
const int i0 = threadIdx.y*mma_A::I;
+#ifdef INT8_MMA_AVAILABLE
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+#endif // INT8_MMA_AVAILABLE
mma_A A[4];
int scA[mma_C::ne/2][4];
const int i = i0 + mma_A::get_i(l);
const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
- A[kvdr/2 + 0].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + 0];
- A[kvdr/2 + 1].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
+ A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0];
+ A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
}
#pragma unroll
sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
}
}
+#else
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
}
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
typedef mma_int_C_I16J8 mma_C;
const int i0 = threadIdx.y*mma_C::I;
+#ifdef INT8_MMA_AVAILABLE
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+#endif // INT8_MMA_AVAILABLE
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
- static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+ static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
- static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+ static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
- static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+ static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
- static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+ static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
- static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+ static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
- static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+ static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
- static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+ static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q3_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
- static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+ static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
- static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+ static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
- static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+ static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
-static int mmq_need_sum(const ggml_type type_x) {
+static bool mmq_need_sum(const ggml_type type_x) {
switch (type_x) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
#if __CUDA_ARCH__ >= CC_VOLTA
__launch_bounds__(WARP_SIZE*nwarps, 1)
#else
- __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2)
+ __launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
static __global__ void mul_mat_q(
constexpr int mmq_y = get_mmq_y_device(mmq_x);
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
- constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back;
+
+#ifdef INT8_MMA_AVAILABLE
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
+ constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
+ constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
extern __shared__ char data_mul_mat_q[];
- int * tile_x_ql = (int *) data_mul_mat_q;
- half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
- int * tile_x_qh = (int *) (tile_x_dm + txs.dm);
- int * tile_x_sc = (int *) (tile_x_qh + txs.qh);
+ int * tile_x_qs = (int *) data_mul_mat_q;
+ half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs);
+ int * tile_x_sc = (int *) (tile_x_dm + txs.dm);
int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
const int blocks_per_row_x = ne00 / qk;
for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
- load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
+ load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
#pragma unroll
for (int kr = 0; kr < qr; ++kr) {
// #pragma unroll // unrolling this loop causes too much register pressure
for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
- vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y, sum, k0);
+ vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0);
}
__syncthreads();
int64_t ne0;
};
+constexpr int mmq_get_nwarps(int mmq_x) {
+ return mmq_x >= 32 ? 8 : 4;
+}
+
+static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
+ const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
+ const int nwarps = mmq_get_nwarps(mmq_x);
+
+ const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+ const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
+ return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
+}
+
template <ggml_type type, int mmq_x, int nwarps>
static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
- const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
- const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
- const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
- const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
+ const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
template <ggml_type type>
void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
- const int id = ggml_cuda_get_device();
- const int nsm = ggml_cuda_info().devices[id].nsm;
- const int cc = ggml_cuda_info().devices[id].cc;
+ const int id = ggml_cuda_get_device();
+ const int nsm = ggml_cuda_info().devices[id].nsm;
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const int smpbo = ggml_cuda_info().devices[id].smpbo;
const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
- if (nwaves < nwaves_best) {
+ if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
mmq_x_best = mmq_x;
nwaves_best = nwaves;
}
switch (mmq_x_best) {
case 8:
- launch_mul_mat_q<type, 8, 4>(args, stream);
+ launch_mul_mat_q<type, 8, mmq_get_nwarps( 8)>(args, stream);
break;
case 16:
- launch_mul_mat_q<type, 16, 4>(args, stream);
+ launch_mul_mat_q<type, 16, mmq_get_nwarps( 16)>(args, stream);
break;
case 24:
- launch_mul_mat_q<type, 24, 4>(args, stream);
+ launch_mul_mat_q<type, 24, mmq_get_nwarps( 24)>(args, stream);
break;
case 32:
- launch_mul_mat_q<type, 32, 8>(args, stream);
+ launch_mul_mat_q<type, 32, mmq_get_nwarps( 32)>(args, stream);
break;
case 40:
- launch_mul_mat_q<type, 40, 8>(args, stream);
+ launch_mul_mat_q<type, 40, mmq_get_nwarps( 40)>(args, stream);
break;
case 48:
- launch_mul_mat_q<type, 48, 8>(args, stream);
+ launch_mul_mat_q<type, 48, mmq_get_nwarps( 48)>(args, stream);
break;
case 56:
- launch_mul_mat_q<type, 56, 8>(args, stream);
+ launch_mul_mat_q<type, 56, mmq_get_nwarps( 56)>(args, stream);
break;
case 64:
- launch_mul_mat_q<type, 64, 8>(args, stream);
+ launch_mul_mat_q<type, 64, mmq_get_nwarps( 64)>(args, stream);
break;
case 72:
- launch_mul_mat_q<type, 72, 8>(args, stream);
+ launch_mul_mat_q<type, 72, mmq_get_nwarps( 72)>(args, stream);
break;
case 80:
- launch_mul_mat_q<type, 80, 8>(args, stream);
+ launch_mul_mat_q<type, 80, mmq_get_nwarps( 80)>(args, stream);
break;
case 88:
- launch_mul_mat_q<type, 88, 8>(args, stream);
+ launch_mul_mat_q<type, 88, mmq_get_nwarps( 88)>(args, stream);
break;
case 96:
- launch_mul_mat_q<type, 96, 8>(args, stream);
+ launch_mul_mat_q<type, 96, mmq_get_nwarps( 96)>(args, stream);
break;
case 104:
- launch_mul_mat_q<type, 104, 8>(args, stream);
+ launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream);
break;
case 112:
- launch_mul_mat_q<type, 112, 8>(args, stream);
+ launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream);
break;
case 120:
- launch_mul_mat_q<type, 120, 8>(args, stream);
+ launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream);
break;
case 128:
- launch_mul_mat_q<type, 128, 8>(args, stream);
+ launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream);
break;
default:
+ fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
GGML_ASSERT(false);
break;
}