case GGML_TYPE_Q5_K:
return MMQ_Q8_1_DS_LAYOUT_DS4;
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ3_S:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_IQ1_S:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
return MMQ_Q8_1_DS_LAYOUT_D4;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
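// For reference (cf. the ds_layout branch in vec_dot_q8_0_q8_1_mma below): with
// MMQ_Q8_1_DS_LAYOUT_D4 only the per-block y scales are needed and they are read
// as plain floats, while with MMQ_Q8_1_DS_LAYOUT_DS4 the y metadata is read as
// half2 pairs whose high half carries the q8_1 block-sum term used by formats
// with a per-block offset (e.g. IQ1_S, whose loader stores d1q and d1q*delta).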
-#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
-#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
-#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
-#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
-#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
-#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
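// MMQ_DP4A_TXS_Q8_0_16 is the Q8_0 tile layout with one float scale per 16 values
// instead of one per 32, hence the doubled dm term: assuming the usual
// WARP_SIZE == 32 and QI8_0 == 8, a tile row holds 2*WARP_SIZE = 64 ints
// (256 values) of qs and, per row, 2*WARP_SIZE/QI8_0 = 8 scales for Q8_0 but
// 4*WARP_SIZE/QI8_0 = 16 scales for Q8_0_16, plus the usual padding terms.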
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+ type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
+ type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
+ type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
tile_x_sizes{0, 0, 0};
}
-#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
-#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4)
#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
-#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7)
-#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
-#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4)
#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
-static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
- return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
- type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
+ return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
- type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
- type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
+ type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
+ type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+ type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
+ type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
+ type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
0;
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + WARP_SIZE);
+ float * x_df = (float *) (x_qs + 2*WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
int * x_qs = (int *) x_tile;
}
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
+ const int qs0 = get_int_b2(bxi->qs, kqsx);
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
#else
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
#endif // INT8_MMA_AVAILABLE
}
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
#ifdef INT8_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + kbxd] = bxi->d;
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
#endif // INT8_MMA_AVAILABLE
}
}
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
-
- typedef mma_int_A_I16K8 mma_A;
- typedef mma_int_B_J8K8 mma_B;
- typedef mma_int_C_I16J8 mma_C;
-
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
- constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
- const int * x_qs = (const int *) x;
- const float * x_df = (const float *) x_qs + WARP_SIZE;
- const int * y_qs = (const int *) y + 4;
- const half2 * y_ds = (const half2 *) y;
-
- mma_A A[ntx][4];
- float dA[ntx][mma_C::ne/2][4];
-
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
-#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
- const int k0 = k00 + k01;
-
-#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + n*mma_A::I + mma_A::get_i(l);
- const int k = k0/QR4_0 + mma_A::get_k(l) % QI4_0;
- const int shift = 4*(mma_A::get_k(l) / QI4_0);
-
- A[n][k01/(QR4_0*QI4_0)].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
- }
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
- dA[n][l][k01/(QR4_0*QI4_0)] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/(QR4_0*QI4_0)];
- }
- }
- }
-
-#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
- mma_B B;
- float dB[mma_C::ne/2];
-
- B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
-
- dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
- }
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma_K8(A[n][k01/(QR4_0*QI4_0)], B);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2][k01/(QR4_0*QI4_0)]*dB[l%2]*C.x[l];
- }
- }
- }
- }
-#else
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
- NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
-}
-
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
- half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+ half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
int * x_qs = (int *) x_tile;
}
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
+ const int qs0 = get_int_b4(bxi->qs, kqsx);
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
#else
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
#endif // INT8_MMA_AVAILABLE
}
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
#ifdef INT8_MMA_AVAILABLE
- x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + kbxd] = bxi->dm;
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
#else
x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
#endif // INT8_MMA_AVAILABLE
}
}
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
-
- typedef mma_int_A_I16K8 mma_A;
- typedef mma_int_A_I16K4 mma_A_K4;
- typedef mma_int_B_J8K8 mma_B;
- typedef mma_int_C_I16J8 mma_C;
-
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
- constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
- const int * x_qs = (const int *) x;
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
- const int * y_qs = (const int *) y + 4;
- const half2 * y_ds = (const half2 *) y;
-
- mma_A A[ntx][4];
- half2 dmA[ntx][mma_C::ne/2][4];
-
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
-#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
- const int k0 = k00 + k01;
-
- A[n][k01/(QR4_1*QI4_1)].load_low(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0/QR4_1, MMQ_MMA_TILE_X_K_Q4_1);
- A[n][k01/(QR4_1*QI4_1)].x[2] = (A[n][k01/(QR4_1*QI4_1)].x[0] >> 4) & 0x0F0F0F0F;
- A[n][k01/(QR4_1*QI4_1)].x[3] = (A[n][k01/(QR4_1*QI4_1)].x[1] >> 4) & 0x0F0F0F0F;
- A[n][k01/(QR4_1*QI4_1)].x[0] &= 0x0F0F0F0F;
- A[n][k01/(QR4_1*QI4_1)].x[1] &= 0x0F0F0F0F;
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
- dmA[n][l][k01/(QR4_1*QI4_1)] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/(QR4_1*QI4_1)];
- }
- }
- }
-
-#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
- mma_B B;
- half2 dsB[mma_C::ne/2];
-
- B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
-
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
- }
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma_K8(A[n][k01/(QR4_1*QI4_1)], B);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- const half2 dmA_dsB = dmA[n][l/2][k01/(QR4_1*QI4_1)]*dsB[l%2];
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
- }
- }
- }
- }
-#else
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
- NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
-}
-
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
}
}
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
+ const half2 * y_ds = (const half2 *) y;
mma_A A[ntx][WARP_SIZE/QI8_0];
float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
- const int k0 = k00 + k01;
-
- mma_B B;
+ mma_B B;
float dB[mma_C::ne/2];
- B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
- dB[l] = y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ } else {
+ dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
}
#pragma unroll
}
}
}
-#else
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
- NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
}
template <int mmq_x, int mmq_y, int nwarps>
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
const int * y_qs = (const int *) y + 4;
const half2 * y_dm = (const half2 *) y;
- mma_A A[ntx][WARP_SIZE/QI8_1];
- half2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+ mma_A A[ntx][WARP_SIZE/QI8_1];
+ float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
const int k0 = k00 + k01;
- dmA[n][l][k01/QI8_1] = x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1];
+ dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
}
}
}
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
- const int k0 = k00 + k01;
-
- mma_B B;
- half2 dsB[mma_C::ne/2];
+ mma_B B;
+ float2 dsB[mma_C::ne/2];
- B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
- dsB[l] = y_dm[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
+ dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
#pragma unroll
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
- const half2 dmA_dsB = dmA[n][l/2][k01/QI8_1]*dsB[l%2];
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+ }
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
+ &x_qs[i*(2*WARP_SIZE + 1) + k0],
+ &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
+ y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
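// vec_dot_q8_0_16_q8_1_dp4a above and vec_dot_q8_0_16_q8_1_mma below treat the x
// tile as int8 data with one float scale per 16 values (one x_df entry per 4 ints).
// Per the mmq_type_traits at the end of this file, IQ2_XS and IQ2_S use both
// variants, and Q3_K uses the mma variant (its loader writes d*scale per 16 values).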
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+#ifdef INT8_MMA_AVAILABLE
+
+ typedef mma_int_A_I16K4 mma_A;
+ typedef mma_int_A_I16K8 mma_A_K8;
+ typedef mma_int_B_J8K4 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+ mma_A A[ntx][8];
+ float dA[ntx][mma_C::ne/2][8];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+ const int k0 = k00 + k01;
+
+ ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+ }
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+ mma_B B[2];
+ float dB[mma_C::ne/2];
+
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C[2];
+ C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+ C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
}
}
}
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
- int * x_sc = (int *) (x_df + 1);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
int * x_qs = (int *) x_tile;
}
}
-#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
- int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
-
- if (need_check) {
- i = min(i, i_max);
- }
-
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
-
-#ifdef INT8_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K] = bxi->d;
-#else
- x_df[i] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
- }
-
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
#ifdef INT8_MMA_AVAILABLE
- x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/8)] = sc;
+ const int8_t * sc8 = (const int8_t *) &sc;
+ const float d = bxi->d;
+
+#pragma unroll
+ for (int l = 0; l < sizeof(int); ++l) {
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
+ }
#else
- x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
+ x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
#endif // INT8_MMA_AVAILABLE
}
+
+#ifndef INT8_MMA_AVAILABLE
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
+ int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+ x_df[i] = bxi->d;
+ }
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_x, int mmq_y, int nwarps>
}
}
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
-
- typedef mma_int_A_I16K4 mma_A;
- typedef mma_int_A_I16K8 mma_A_K8;
- typedef mma_int_B_J8K4 mma_B;
- typedef mma_int_C_I16J8 mma_C;
-
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
- constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
- const int * x_qs = (const int *) x;
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
- const int * x_sc = (const int *) x_df + 1;
- const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
-
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
- mma_A A[ntx][8];
- int scA[ntx][mma_C::ne/2][8];
- float dA[ntx][mma_C::ne/2];
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
-#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
- const int k0 = k00 + k01;
-
- ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
- }
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
-#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
- const int k0 = k00 + k01;
-
- const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + k0/16];
- const int8_t * sc = (const int8_t *) &sc_packed;
-
-#pragma unroll
- for (int ksc = 0; ksc < sizeof(int); ++ksc) {
- scA[n][l][k01/4 + ksc] = sc[ksc];
- }
- }
-
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K];
- }
- }
-
-#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
- mma_B B[2];
- float dB[mma_C::ne/2];
-
- B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
- B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
-
- dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
- }
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C[2];
- C[0].mma_K4(A[n][k01/4 + 0], B[0]);
- C[1].mma_K4(A[n][k01/4 + 1], B[1]);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*
- (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1]);
- }
- }
- }
- }
-#else
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
- NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
+ // scale arrangement after the following two lines:
+ // - ksc == 0: sc0, sc1, sc2, sc3
+ // - ksc == 1: sc4, sc5, sc6, sc7
+ // - ksc == 2: m0, m1, m2, m3
+ // - ksc == 3: m4, m5, m6, m7
+ return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
+ ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
}
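// Worked example for ksc == 3, assuming the standard q4_K/q5_K 12-byte scale packing
// (bytes 8..11: low 4 bits of sc4..sc7/m4..m7 in their low/high nibbles; bits 6..7 of
// bytes 0..3 and 4..7: high bits of sc4..sc7 and m4..m7, respectively):
//   lower 4 bits: (scales[2] >> 4) & 0x0F0F0F0F -> high nibbles of bytes 8..11 (low 4 bits of m4..m7)
//   upper 2 bits: (scales[1] >> 2) & 0x30303030 -> bits 6..7 of bytes 4..7 (high 2 bits of m4..m7)
// so the returned int packs the four 6-bit values m4, m5, m6, m7, one per byte.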
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
- half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
- int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K);
+ half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
int * x_qs = (int *) x_tile;
int * x_sc = (int *) (x_dm + txs.dm);
#endif // INT8_MMA_AVAILABLE
- const int kbx = 0; // threadIdx.x / QI4_K
- const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
-
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;
i = min(i, i_max);
}
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+ const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
#else
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
#endif // INT8_MMA_AVAILABLE
}
- const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
- const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+#ifdef INT8_MMA_AVAILABLE
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
- int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
+ int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+ const int * scales = (const int *) bxi->scales;
+ const int ksc = threadIdx.x % (WARP_SIZE/16);
+
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
+ const uint8_t * m8 = (const uint8_t *) &m32;
+
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+ for (int l = 0; l < sizeof(int); ++l) {
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+ }
+ }
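// The half2 written above packs (d*sc, -dmin*m) per 16-value group, so the generic
// vec_dot_q8_1_q8_1_mma can compute d*sc*d8*<q4,q8> - dmin*m*(q8 block-sum term)
// without a dedicated q4_K kernel; bxi->dm * make_half2(1.0f, -1.0f) provides the
// sign flip on the min.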
-#ifdef INT8_MMA_AVAILABLE
- x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + kbxd] = bxi->dm;
#else
- x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K + kbxd] = bxi->dm;
-#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
+ int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+ x_dm[i] = bxi->dm;
}
#pragma unroll
const int * scales = (const int *) bxi->scales;
const int ksc = threadIdx.x % (WARP_SIZE/8);
+ const int scales8 = unpack_scales_q45_K(scales, ksc);
- // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
- int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
- scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
-
-#ifdef INT8_MMA_AVAILABLE
- x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + ksc] = scales8;
-#else
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
-#endif // INT8_MMA_AVAILABLE
+ x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
}
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_x, int mmq_y, int nwarps>
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
&x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
- x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
- }
- }
- }
-}
-
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
-
- typedef mma_int_A_I16K8 mma_A;
- typedef mma_int_B_J8K8 mma_B;
- typedef mma_int_C_I16J8 mma_C;
-
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
- constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
- const int * x_qs = (const int *) x;
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
- const int * x_sc = (const int *) x_dm + WARP_SIZE/QI4_K;
- const int * y_qs = (const int *) y + 4;
- const half2 * y_ds = (const half2 *) y;
-
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
- mma_A A[ntx][4];
- int scA[ntx][mma_C::ne/2][4];
- int mA[ntx][mma_C::ne/2][4];
- half2 dmA[ntx][mma_C::ne/2];
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
-#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
- const int k0 = k00 + k01;
-
- A[n][k01/8 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0/QR4_K, MMQ_MMA_TILE_X_K_Q4_K);
-
-#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- A[n][k01/8 + 1].x[l] = (A[n][k01/8 + 0].x[l] >> 4) & 0x0F0F0F0F;
- A[n][k01/8 + 0].x[l] &= 0x0F0F0F0F;
- }
- }
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
- const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 0)];
- const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 2)];
-
- const uint8_t * sc = (const uint8_t *) &sc_packed;
- const uint8_t * m = (const uint8_t *) &m_packed;
-
-#pragma unroll
- for (int ksc = 0; ksc < sizeof(int); ++ksc) {
- scA[n][l][ksc] = sc[ksc];
- mA[n][l][ksc] = m[ksc];
+ x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
}
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
-
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K];
- }
}
-
-#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- float tmpd[ntx][mma_C::ne] = {{0.0f}};
- float tmpm[ntx][mma_C::ne] = {{0.0f}};
-
-#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
- mma_B B;
- half2 dsB[mma_C::ne/2];
-
- B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
-
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
- }
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma_K8(A[n][k01/8], B);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]);
- tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]);
- }
- }
- }
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
-#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
- }
- }
- }
-#else
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
- NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
- int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
int * x_qs = (int *) x_tile;
int * x_sc = (int *) (x_dm + txs.dm);
#endif // INT8_MMA_AVAILABLE
- const int kbx = 0; // threadIdx.x / QI5_K
- const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
-
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;
i = min(i, i_max);
}
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
- const int ky = QR5_K*kqsx;
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+ const int ky = QR5_K*threadIdx.x;
- const int ql = get_int_b4(bxi->qs, kqsx);
+ const int ql = get_int_b4(bxi->qs, threadIdx.x);
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
const int ql1 = (ql >> 4) & 0x0F0F0F0F;
- const int qh = get_int_b4(bxi->qh, kqsx % (QI5_K/4));
- const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
- const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+ const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
+ const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+ const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
- const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
+ const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq0] = ql0 | qh0;
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq1] = ql1 | qh1;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
#else
x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
#endif // INT8_MMA_AVAILABLE
}
- const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
- const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+#ifdef INT8_MMA_AVAILABLE
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
- int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
+ int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+ const int * scales = (const int *) bxi->scales;
+ const int ksc = threadIdx.x % (WARP_SIZE/16);
+
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
+ const uint8_t * m8 = (const uint8_t *) &m32;
+
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+ for (int l = 0; l < sizeof(int); ++l) {
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+ }
+ }
-#ifdef INT8_MMA_AVAILABLE
- x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + kbxd] = bxi->dm;
#else
- x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + kbxd] = bxi->dm;
-#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
+ int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+ x_dm[i] = bxi->dm;
}
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
- int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
+ int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8);
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
const int * scales = (const int *) bxi->scales;
const int ksc = threadIdx.x % (WARP_SIZE/8);
+ const int scales8 = unpack_scales_q45_K(scales, ksc);
- // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
- int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
- scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
-
-#ifdef INT8_MMA_AVAILABLE
- x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + ksc] = scales8;
-#else
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
-#endif // INT8_MMA_AVAILABLE
+ x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
}
+#endif // INT8_MMA_AVAILABLE
}
template <int mmq_x, int mmq_y, int nwarps>
// #pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
- const int k0 = k00 + k01;
-
-#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
-
-#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
-
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
-
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
- &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
- x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
- }
- }
- }
-}
-
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
-
- typedef mma_int_A_I16K8 mma_A;
- typedef mma_int_B_J8K8 mma_B;
- typedef mma_int_C_I16J8 mma_C;
-
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
- constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
- const int * x_qs = (const int *) x;
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
- const int * x_sc = (const int *) x_dm + WARP_SIZE/QI5_K;
- const int * y_qs = (const int *) y + 4;
- const half2 * y_ds = (const half2 *) y;
-
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
- mma_A A[ntx][4];
- int scA[ntx][mma_C::ne/2][4];
- int mA[ntx][mma_C::ne/2][4];
- half2 dmA[ntx][mma_C::ne/2];
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
-#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
- const int k0 = k00 + k01;
-
- A[n][k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + k0, MMQ_MMA_TILE_X_K_Q5_K);
- }
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
- const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 0)];
- const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 2)];
-
- const uint8_t * sc = (const uint8_t *) &sc_packed;
- const uint8_t * m = (const uint8_t *) &m_packed;
-
-#pragma unroll
- for (int ksc = 0; ksc < sizeof(int); ++ksc) {
- scA[n][l][ksc] = sc[ksc];
- mA[n][l][ksc] = m[ksc];
- }
- }
-
- #pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K];
- }
- }
-
-#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- float tmpd[ntx][mma_C::ne] = {{0.0f}};
- float tmpm[ntx][mma_C::ne] = {{0.0f}};
-
-#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
- const int k0 = k00 + k01;
-
- mma_B B;
- half2 dsB[mma_C::ne/2];
-
- B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
-
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
- }
+ const int k0 = k00 + k01;
#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma_K8(A[n][k01/8], B);
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]);
- tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]);
- }
- }
- }
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
-#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
+
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
+ &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+ x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
}
}
-#else
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
- NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
int * x_sc = (int *) (x_df + txs.dm);
#endif // INT8_MMA_AVAILABLE
- const int kbx = 0; // threadIdx.x / QI6_K
- const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
-
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;
i = min(i, i_max);
}
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
- const int ky = QR6_K*kqsx;
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
- const int ql = get_int_b2(bxi->ql, kqsx);
+ const int ql = get_int_b2(bxi->ql, threadIdx.x);
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
const int ql1 = (ql >> 4) & 0x0F0F0F0F;
- const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
- const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
- const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;
+ const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
+ const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
+ const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030;
- const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
- const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
+ const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
+ const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
}
}
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % (QI2_XXS/2);
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
+
+ const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
+ const uint8_t * aux8 = (const uint8_t *) &q2;
+ const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1);
+
+#pragma unroll
+ for (int l = 0; l < QR2_XXS; ++l) {
+ const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
+ const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
+
+ const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
+ const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+
+ const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
+ const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int ls = aux32 >> 28;
+ const float d = bxi->d;
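// (ls*d + d/2)/4 == d*(ls + 0.5f)*0.25f, i.e. the iq2_xxs group scale folded into a
// single float per 32 values; the other IQ loaders below fold their scales the same
// way with their respective constants.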
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % (QI2_XS/2);
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
+
+ const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+ const uint16_t * q2 = (const uint16_t *) &q2_packed;
+
+ #pragma unroll
+ for (int l = 0; l < QR2_XS; ++l) {
+ const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+
+ const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int ls = bxi->scales[kqsx];
+ const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#else
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % (QI2_S/2);
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
+
+ const int qs_packed = get_int_b2(bxi->qs, kqsx);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bxi->qh[kqsx];
+
+ const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx);
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+ for (int l = 0; l < QR2_S; ++l) {
+ const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300)));
+
+ const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+ const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+ const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
+ const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int ls = bxi->scales[kqsx];
+ const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#else
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % (QI3_XXS/2);
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
+
+ const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+ const uint8_t * q3 = (const uint8_t *) &q3_packed;
+ const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx);
+
+#pragma unroll
+ for (int l = 0; l < QR3_XXS; ++l) {
+ const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
+
+ const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
+
+ const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int ls = aux32 >> 28;
+ const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % (QI3_S/2);
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
+
+ const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bxi->qh[kqsx];
+
+ const int signs_packed_32 = get_int_b2(bxi->signs, kqsx);
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+ for (int l = 0; l < QR3_S; ++l) {
+ const int2 grid_pos = make_int2(
+ iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)],
+ iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]);
+
+ const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+ const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+ const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
+ const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
+ const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_ds = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % QI1_S;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
+ int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
+
+ const int qs_packed = get_int_b2(bxi->qs, kqsx);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bxi->qh[kqsx];
+
+ #pragma unroll
+ for (int l = 0; l < QR1_S/2; ++l) {
+ const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];
+
+ const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+ const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
+ const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
+
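// delta is -1.0f + IQ1S_DELTA when qh bit 15 is clear and -1.0f - IQ1S_DELTA when it
// is set; storing (d1q, d1q*delta) lets the generic q8_1 dot products apply the
// constant per-value offset through the y block-sum term (cf. the dmA.y*dsB.y
// contribution in vec_dot_q8_1_q8_1_mma above).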
+#ifdef INT8_MMA_AVAILABLE
+ x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
+#else
+ x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q3_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
+ static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
+ static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
+ static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
+ static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
+ static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
+ static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);