const int iqs4 = k_KQ % QI4_0;
const int shift = k_KQ & (QI8_1/2);
- const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+ const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = ggml_cuda_dp4a(v, u, 0);
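// Illustrative sketch (dp4a_reference is hypothetical; ggml_cuda_dp4a is the
// real helper): the call above treats v and u as four packed int8 values and
// accumulates their pairwise products, matching the __dp4a intrinsic:
static __device__ __forceinline__ int dp4a_reference(const int v, const int u, int acc) {
    const int8_t * v8 = (const int8_t *) &v;
    const int8_t * u8 = (const int8_t *) &u;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        acc += v8[i] * u8[i]; // per-byte signed multiply-accumulate
    }
    return acc;
}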
const int iqs4 = k_KQ % QI4_1;
const int shift = k_KQ & (QI8_1/2);
- const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+ const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = ggml_cuda_dp4a(v, u, 0);
const int iqs4 = k_KQ % QI5_0;
const int iqs8 = k_KQ % QI8_1;
const int shift = k_KQ & (QI8_1/2);
- int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
- const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
+ int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+ const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
v |= (vh << 4) & 0x00000010; // 0 -> 4
v |= (vh << 11) & 0x00001000; // 1 -> 12
v |= (vh << 18) & 0x00100000; // 2 -> 20
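// Worked mapping (illustration): after the 0x0F0F0F0F mask, v holds four
// 4-bit values in its byte lanes and vh carries their 5th bits. Bit b of vh
// must land at bit 8*b + 4 of v, i.e. a left shift of 7*b + 4: b=0 -> <<4
// (bit 4), b=1 -> <<11 (bit 12), b=2 -> <<18 (bit 20), b=3 -> <<25 (bit 28),
// each isolated by a mask with a single set bit in the target byte lane.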
const int iqs4 = k_KQ % QI5_1;
const int iqs8 = k_KQ % QI8_1;
const int shift = k_KQ & (QI8_1/2);
- int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
- const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
+ int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+ const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
v |= (vh << 4) & 0x00000010; // 0 -> 4
v |= (vh << 11) & 0x00001000; // 1 -> 12
v |= (vh << 18) & 0x00100000; // 2 -> 20
const int ib = k_KQ / QI8_0;
const int iqs = k_KQ % QI8_0;
- const int v = get_int_from_int8(K_q8_0[ib].qs, iqs);
+ const int v = get_int_b2(K_q8_0[ib].qs, iqs);
T Q_d;
if (std::is_same<T, half>::value) {
const T d = x[ib].d;
const int ql0 = x[ib].qs[iqs];
- const int qh0 = get_int_from_uint8(x[ib].qh, 0);
+ const int qh0 = get_int_b2(x[ib].qh, 0);
const int ql = ((ql0 >> (4*shift)) & 0x0F);
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh) - 16;
const half2 dm = x[ib].dm;
const int ql0 = x[ib].qs[iqs];
- const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0);
+ const int qh0 = get_int_b4(x[ib].qh, 0);
const int ql = ((ql0 >> (4*shift)) & 0x0F);
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh);
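// Summary of the two dequantization paths above (illustration): q5_0 stores
// zero-centered values, so the weight is d * (q5 - 16) with d = x[ib].d;
// q5_1 carries an explicit minimum in dm = (d, m), so the weight is d*q5 + m
// and no offset is subtracted:
//   float w_q5_0 = (float) d        * (float) q;                    // q already centered above
//   float w_q5_1 = __low2float(dm)  * (float) q + __high2float(dm);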
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
- type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
- type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
- type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
- type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
- type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
- type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
- type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
- type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
- type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+ type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
+ type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
+ type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
+ type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
+ type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
+ type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
+ type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
+ type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+ type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q5_0 :
+ type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q5_0 :
tile_x_sizes{0, 0, 0};
}
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
- type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
- type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
- type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
- type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
- type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
- type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
- type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
- type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
- type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+ type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
+ type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
+ type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
+ type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
+ type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
+ type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
+ type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
+ type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+ type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q5_0 :
+ type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q5_0 :
0;
}
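// The IQ4_XS/IQ4_NL entries above intentionally map to the Q5_0 constants:
// the load_tiles_iq4_* functions below expand the 4-bit table indices to
// 8-bit values while filling the tile, so the tiles end up in the same
// layout as Q5_0 and the Q5_0 tile sizes, strides, and vec_dot kernels can
// be reused unchanged.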
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
#else
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
#endif // INT8_MMA_AVAILABLE
}
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
#else
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
#endif // INT8_MMA_AVAILABLE
}
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
- const int ql = get_int_from_uint8(bxi->qs, kqsx);
- const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+ const int ql = get_int_b2(bxi->qs, kqsx);
+ const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
int qs0 = (ql >> 0) & 0x0F0F0F0F;
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
- const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
- const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+ const int ql = get_int_b4(bxi->qs, kqsx);
+ const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
int qs0 = (ql >> 0) & 0x0F0F0F0F;
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
#else
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
#endif // INT8_MMA_AVAILABLE
}
const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
- const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
+ const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
#pragma unroll
for (int l = 0; l < QR2_K; ++l) {
const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
- const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
- const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
+ const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
+ const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
#pragma unroll
for (int l = 0; l < QR3_K; ++l) {
const int ksc_low = ksc % (QI3_K/8);
const int shift_low = 4 * (ksc / (QI3_K/8));
- const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+ const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
const int ksc_high = QI3_K/8;
const int shift_high = 2 * ksc;
- const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+ const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
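// Worked example (illustration): each q3_K scale is 6 bits. The low 4 bits
// of four consecutive scales come from the first 8 bytes of bxi->scales
// (sc_low), the high 2 bits from the last 4 bytes (sc_high, shifted into bit
// positions 4..5 of each byte by the << 4 and 0x30303030 mask). __vsubss4
// then subtracts 0x20 = 32 from each byte lane with signed saturation,
// centering the scales to the range -32..31.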
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
#else
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
#endif // INT8_MMA_AVAILABLE
}
const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
const int ky = QR5_K*kqsx;
- const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ const int ql = get_int_b4(bxi->qs, kqsx);
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
const int ql1 = (ql >> 4) & 0x0F0F0F0F;
- const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+ const int qh = get_int_b4(bxi->qh, kqsx % (QI5_K/4));
const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
const int ky = QR6_K*kqsx;
- const int ql = get_int_from_uint8(bxi->ql, kqsx);
+ const int ql = get_int_b2(bxi->ql, kqsx);
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
const int ql1 = (ql >> 4) & 0x0F0F0F0F;
- const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+ const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;
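// Illustration: per byte lane, ql supplies the low 4 bits and qh the high 2
// bits of a 6-bit q6_K value (0x30303030 keeps bits 4..5); the surrounding
// code then centers each lane by subtracting 0x20 = 32.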
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
#ifdef INT8_MMA_AVAILABLE
- x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+ x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
#else
- x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+ x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
#endif // INT8_MMA_AVAILABLE
}
}
#endif // INT8_MMA_AVAILABLE
}
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = threadIdx.x / QI4_NL;
+ const int kqsx = threadIdx.x % QI4_NL;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
+
+ const int aux_q4 = get_int_b2(bxi->qs, kqsx);
+ const int2 v = get_int_from_table_16(aux_q4);
+ const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
+ int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = __half2float(bxi->d);
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
+#endif // INT8_MMA_AVAILABLE
+ }
+}
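// Illustration of the table lookup used above (table_16_reference is a
// hypothetical scalar rewrite; get_int_from_table_16 and kvalues_iq4nl are
// the real helpers): the eight 4-bit indices in aux_q4 are mapped through
// the nonlinear iq4_nl value table, low nibbles returned in .x and high
// nibbles in .y. With k0 = 8*(tid/4) + tid%4, each group of 8 tile ints then
// holds the 32 values of one block in natural order (.x ints in slots 0..3,
// .y ints in slots 4..7):
static __device__ __forceinline__ int2 table_16_reference(const int q4) {
    int8_t lo[4];
    int8_t hi[4];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        lo[i] = kvalues_iq4nl[(q4 >> (8*i    )) & 0x0F]; // low nibble of byte i
        hi[i] = kvalues_iq4nl[(q4 >> (8*i + 4)) & 0x0F]; // high nibble of byte i
    }
    int2 ret;
    memcpy(&ret.x, lo, sizeof(ret.x));
    memcpy(&ret.y, hi, sizeof(ret.y));
    return ret;
}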
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = 0; // threadIdx.x / QI4_XS
+ const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
+
+ const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+ const int2 v = get_int_from_table_16(aux_q4);
+ const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
+
+ const float d = __half2float(bxi->d);
+
+ const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
+ | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + threadIdx.x % 8] = d * (ls - 32);
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
+#endif // INT8_MMA_AVAILABLE
+ }
+}
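// Worked example (illustration): each of the 8 iq4_xs sub-block scales is 6
// bits, with the low 4 bits packed two-per-byte in scales_l and the high 2
// bits in the 16-bit scales_h field. For sub-block s = threadIdx.x % 8:
//   low  = (scales_l[s/2] >> (4*(s%2))) & 0x0F
//   high = (scales_h >> (2*s)) & 0x03
//   ls   = low | (high << 4)              // 0..63
// so the effective sub-block scale is d * (ls - 32).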
+
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
+ static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
+ static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
static bool mmq_need_sum(const ggml_type type_x) {
switch (type_x) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_K:
return true;
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_NL:
return false;
default:
GGML_ASSERT(false);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
// -------------------------------------------------------------------------------------------------------------------------
#include "common.cuh"
#include <cstdint>
-static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
- const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
-
- int x32 = 0;
- x32 |= x16[0] << 0;
- x32 |= x16[1] << 16;
-
- return x32;
-}
-
-static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
- const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
-
- int x32 = 0;
- x32 |= x16[0] << 0;
- x32 |= x16[1] << 16;
-
- return x32;
-}
-
-static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
- return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
-}
-
-static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
- return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
-}
-
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
- const uint16_t * x16 = (const uint16_t *) x;
+ const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
int x32 = x16[2*i32 + 0] << 0;
x32 |= x16[2*i32 + 1] << 16;
return x32;
}
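// For comparison, a sketch of the 4-byte-aligned counterpart used above
// (get_int_b4 as it appears elsewhere in this file; assumes the pointer is
// at least 4-byte aligned):
static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
    return ((const int *) x)[i32]; // single aligned 32-bit load
}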
#define VDR_IQ2_XXS_Q8_1_MMVQ 2
+#define VDR_IQ2_XXS_Q8_1_MMQ 2
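// VDR = vec dot ratio: how many contiguous integers each thread processes
// when computing the dot product. The new *_MMQ constants extend to the MMQ
// path the knob the MMVQ kernels already had.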
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
}
#define VDR_IQ2_XS_Q8_1_MMVQ 2
+#define VDR_IQ2_XS_Q8_1_MMQ 2
static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
}
#define VDR_IQ2_S_Q8_1_MMVQ 2
+#define VDR_IQ2_S_Q8_1_MMQ 2
static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
}
#define VDR_IQ3_XXS_Q8_1_MMVQ 2
+#define VDR_IQ3_XXS_Q8_1_MMQ 2
static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
}
#define VDR_IQ3_S_Q8_1_MMVQ 2
+#define VDR_IQ3_S_Q8_1_MMQ 2
// TODO: don't use lookup table for signs
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
return d * sumi;
}
+#define VDR_IQ1_S_Q8_1_MMVQ 1
+#define VDR_IQ1_S_Q8_1_MMQ 1
+
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;
return d1q * (ds.x*sumi + ds.y*delta);
}
+#define VDR_IQ1_M_Q8_1_MMVQ 1
+#define VDR_IQ1_M_Q8_1_MMQ 1
+
static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
}
#define VDR_IQ4_NL_Q8_1_MMVQ 2
+#define VDR_IQ4_NL_Q8_1_MMQ 4
static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
}
#define VDR_IQ4_XS_Q8_1_MMVQ 4
+#define VDR_IQ4_XS_Q8_1_MMQ 4
static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {