#include <cstdint>
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
+#define MMQ_ITER_K 256 // Number of values along the k dimension processed per MMQ iteration.
+#define MMQ_NWARPS 8   // Number of warps per CUDA block for the MMQ kernels.
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
-typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0);
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
+enum mmq_q8_1_ds_layout {
+ MMQ_Q8_1_DS_LAYOUT_D4,
+ MMQ_Q8_1_DS_LAYOUT_DS4,
+ MMQ_Q8_1_DS_LAYOUT_D2S6,
+};
+
struct block_q8_1_mmq {
- half2 ds[4];
- int8_t qs[4*QK8_1];
+ // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
+ // The y float data is first grouped as blocks of 128 values.
+ // These blocks are then treated as individual data values and transposed.
+ //
+    // To avoid shared memory bank conflicts, each block is padded with 16 bytes.
+    // This padding is also used to store block scales/partial sums.
+    // Multiplying the quantized data by the scales yields the original, unquantized values.
+    // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
+    // and are only needed for performance reasons.
+ //
+ // The exact data stored depends on the x data type.
+ union {
+ float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
+ half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
+ half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
+                              //     stored as d0,d1,s1,s2,s3,s4,s5,s6
+ };
+ int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
};
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
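+
+// A minimal illustrative sketch (not used by the kernels below) of how a value would be read back
+// from the MMQ_Q8_1_DS_LAYOUT_D4 variant. It assumes QK8_1 == 32, so that one block holds 4*32 = 128
+// quantized values, and that qs[] keeps those values in their original order with d4[k/32] being the
+// scale of the 32-value group containing value k. The function name is purely hypothetical:
+static __host__ __device__ float dequantize_q8_1_mmq_d4(const block_q8_1_mmq & b, const int k) {
+    return b.d4[k/32] * b.qs[k]; // group scale times the int8 quant of value k
+}
+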
+static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
+ switch (type_x) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_Q5_0:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_Q5_1:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_Q8_0:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_Q2_K:
+ return MMQ_Q8_1_DS_LAYOUT_D2S6;
+ case GGML_TYPE_Q3_K:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_NL:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+}
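+
+// A small convenience sketch (hypothetical, not referenced elsewhere): only the pure-scale D4 layout
+// stores no partial sums; the other layouts also carry per-group sums of the y values, which the dot
+// products of the corresponding x types above make use of.
+static constexpr __host__ __device__ bool mmq_q8_1_ds_layout_has_sums(const mmq_q8_1_ds_layout layout) {
+    return layout != MMQ_Q8_1_DS_LAYOUT_D4;
+}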
+
struct tile_x_sizes {
int qs;
int dm;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
-#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
-#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
-#define MMQ_DP4A_TXS_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
-#define MMQ_DP4A_TXS_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
-#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
-#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
-#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
-#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
- type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
- type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
+ type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
- type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q5_0 :
- type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q5_0 :
+ type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
tile_x_sizes{0, 0, 0};
}
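+
+// Illustrative compile-time spot check of the formulas above (assumes WARP_SIZE == 32): the "+ mmq_y"
+// term is one element of per-row padding used to avoid shared memory bank conflicts, so for
+// mmq_y == 64 the Q4_0 tile holds 64*32 + 64 ints of quantized data.
+static_assert(mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, 64).qs == 64*32 + 64, "Unexpected Q4_0 dp4a tile size.");
+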
-#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
-#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4)
-#define MMQ_MMA_TILE_X_K_Q5_0 (2*WARP_SIZE + WARP_SIZE/QI5_0 + 4)
-#define MMQ_MMA_TILE_X_K_Q5_1 (2*WARP_SIZE + WARP_SIZE/QI5_1 + 4)
-#define MMQ_MMA_TILE_X_K_Q8_0 (1*WARP_SIZE + WARP_SIZE/QI8_0 + 0)
-#define MMQ_MMA_TILE_X_K_Q2_K (1*WARP_SIZE + WARP_SIZE + 4)
-#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/QI3_K + WARP_SIZE/4 + 2)
-#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
-#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
-#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
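+
+// Illustrative spot check (assumes WARP_SIZE == 32 and QI8_0 == 8): a Q8_0 MMA tile row is
+// 2*32 ints of quantized data + 8 floats of scales + 4 ints of padding = 76 ints wide.
+static_assert(MMQ_MMA_TILE_X_K_Q8_0 == 76, "Unexpected Q8_0 MMA tile row size (assumes WARP_SIZE == 32).");
+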
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
- type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
- type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
+ type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
- type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q5_0 :
- type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q5_0 :
+ type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
0;
}
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
-#define MMQ_NWARPS 8
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
const int * x_qs = (const int *) x;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
- const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+ const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
- int u[2*VDR_Q4_0_Q8_1_MMQ];
+ int u[2*VDR_Q4_0_Q8_1_MMQ];
#pragma unroll
- for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
- u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
- u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];
- }
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
+ }
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
- (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
- y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+ (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
+ x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
- mma_A A[ntx];
- float dA[ntx][mma_C::ne/2];
+ mma_A A[ntx][4];
+ float dA[ntx][mma_C::ne/2][4];
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + n*mma_A::I + mma_A::get_i(l);
- const int k = k0 + mma_A::get_k(l) % QI4_0;
- const int shift = 4*(mma_A::get_k(l) / QI4_0);
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
+ const int k0 = k00 + k01;
- A[n].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
- }
+#pragma unroll
+ for (int l = 0; l < mma_A::ne; ++l) {
+ const int i = i0 + n*mma_A::I + mma_A::get_i(l);
+ const int k = k0/QR4_0 + mma_A::get_k(l) % QI4_0;
+ const int shift = 4*(mma_A::get_k(l) / QI4_0);
+
+ A[n][k01/(QR4_0*QI4_0)].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+ }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/QI4_0];
+ dA[n][l][k01/(QR4_0*QI4_0)] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/(QR4_0*QI4_0)];
+ }
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- mma_B B;
- float dB[mma_C::ne/2];
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
+ mma_B B;
+ float dB[mma_C::ne/2];
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
- dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
- }
+ dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma_K8(A[n], B);
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n][k01/(QR4_0*QI4_0)], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2][k01/(QR4_0*QI4_0)]*dB[l%2]*C.x[l];
+ }
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
const int * x_qs = (const int *) x;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
- const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+ const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
- int u[2*VDR_Q4_1_Q8_1_MMQ];
+ int u[2*VDR_Q4_1_Q8_1_MMQ];
#pragma unroll
- for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
- u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
- u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE];
- }
+ for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
+ }
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
- (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
- y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+ (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
+ x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
- mma_A A[ntx];
- half2 dmA[ntx][mma_C::ne/2];
+ mma_A A[ntx][4];
+ half2 dmA[ntx][mma_C::ne/2][4];
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
#pragma unroll
for (int n = 0; n < ntx; ++n) {
- ((mma_A_K4 *) &A[n])[0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0, MMQ_MMA_TILE_X_K_Q4_1);
- A[n].x[2] = (A[n].x[0] >> 4) & 0x0F0F0F0F;
- A[n].x[3] = (A[n].x[1] >> 4) & 0x0F0F0F0F;
- A[n].x[0] &= 0x0F0F0F0F;
- A[n].x[1] &= 0x0F0F0F0F;
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
+ const int k0 = k00 + k01;
+
+ A[n][k01/(QR4_1*QI4_1)].load_low(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0/QR4_1, MMQ_MMA_TILE_X_K_Q4_1);
+ A[n][k01/(QR4_1*QI4_1)].x[2] = (A[n][k01/(QR4_1*QI4_1)].x[0] >> 4) & 0x0F0F0F0F;
+ A[n][k01/(QR4_1*QI4_1)].x[3] = (A[n][k01/(QR4_1*QI4_1)].x[1] >> 4) & 0x0F0F0F0F;
+ A[n][k01/(QR4_1*QI4_1)].x[0] &= 0x0F0F0F0F;
+ A[n][k01/(QR4_1*QI4_1)].x[1] &= 0x0F0F0F0F;
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/QI4_1];
+ dmA[n][l][k01/(QR4_1*QI4_1)] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/(QR4_1*QI4_1)];
+ }
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- mma_B B;
- half2 dsB[mma_C::ne/2];
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
+ mma_B B;
+ half2 dsB[mma_C::ne/2];
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
- }
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
+ }
#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma_K8(A[n], B);
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n][k01/(QR4_1*QI4_1)], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+ for (int l = 0; l < mma_C::ne; ++l) {
+ const half2 dmA_dsB = dmA[n][l/2][k01/(QR4_1*QI4_1)]*dsB[l%2];
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+ }
}
}
}
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
#else
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
#ifdef INT8_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = bxi->d;
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
#endif // INT8_MMA_AVAILABLE
}
}
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
- const int * x_qs = (const int *) x;
- const float * x_df = (const float *) x_qs + txs.qs;
- const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
-
-#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
-
-#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
-
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
- (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
- x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
- }
- }
-}
-
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-#ifdef INT8_MMA_AVAILABLE
-
- typedef mma_int_A_I16K8 mma_A;
- typedef mma_int_B_J8K8 mma_B;
- typedef mma_int_C_I16J8 mma_C;
-
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
- constexpr int rows_per_warp = 2 * granularity;
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
- const int * x_qs = (const int *) x;
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
- const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
-
- mma_A A[ntx];
- float dA[ntx][mma_C::ne/2];
-
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_0);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
-
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0];
- }
- }
-
-#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- mma_B B;
- float dB[mma_C::ne/2];
-
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
-
- dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
- }
-
-#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma_K8(A[n], B);
-
-#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
- }
- }
- }
-#else
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
- NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
-}
-
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
#else
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
#ifdef INT8_MMA_AVAILABLE
- x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + kbxd] = bxi->dm;
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
#else
x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
#endif // INT8_MMA_AVAILABLE
}
}
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_tile + 2*WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = threadIdx.x / QI8_0;
+ const int kqsx = threadIdx.x % QI8_0;
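+
+    // Each tile row spans 2*WARP_SIZE ints of q8 data (= MMQ_ITER_K int8 values), so in the loop below
+    // every thread loads one int from block kbx and one int from block kbx + WARP_SIZE/QI8_0 of the same row.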
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
+ x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
+ int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#else
+ x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
const int * x_qs = (const int *) x;
- const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
- const half2 * y_ds = (const half2 *) y;
+ const float * y_df = (const float *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
- (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
- x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+ (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
+ x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+ }
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
- const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
+ const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
const int * y_qs = (const int *) y + 4;
- const half2 * y_ds = (const half2 *) y;
+ const float * y_df = (const float *) y;
- mma_A A[ntx];
- half2 dmA[ntx][mma_C::ne/2];
+ mma_A A[ntx][WARP_SIZE/QI8_0];
+ float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+ const int i0 = (threadIdx.y/ntx)*rows_per_warp;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
- A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1);
-
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ const int k0 = k00 + k01;
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + k0/QI5_1];
+ A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
}
- }
-
-#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- mma_B B;
- half2 dsB[mma_C::ne/2];
-
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
-
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
- }
+ const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma_K8(A[n], B);
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ const int k0 = k00 + k01;
-#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+ dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
}
}
}
-#else
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
- NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
-}
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ const int k0 = k00 + k01;
-#ifdef INT8_MMA_AVAILABLE
- int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_tile + WARP_SIZE);
-#else
- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
- int * x_qs = (int *) x_tile;
- float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+ mma_B B;
+ float dB[mma_C::ne/2];
- const int kbx = threadIdx.x / QI8_0;
- const int kqsx = threadIdx.x % QI8_0;
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
-
- if (need_check) {
- i = min(i, i_max);
- }
-
- const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
-#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
-#else
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
-#endif // INT8_MMA_AVAILABLE
- }
+ dB[l] = y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
+ }
- const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
- const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n][k01/QI8_0], B);
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
- int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row;
-
- if (need_check) {
- i = min(i, i_max);
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+ }
+ }
}
-
- const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
-
-#ifdef INT8_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+ }
#else
- x_df[i*(WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+ NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
- }
}
template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
const int * x_qs = (const int *) x;
- const float * x_df = (const float *) x_qs + txs.qs;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
+ const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
- (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
- y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+ (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
- const float * x_df = (const float *) x_qs + WARP_SIZE;
+ const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
+ const half2 * y_dm = (const half2 *) y;
- mma_A A[ntx];
- float dA[ntx][mma_C::ne/2];
+ mma_A A[ntx][WARP_SIZE/QI8_1];
+ half2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
- A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ const int k0 = k00 + k01;
+
+ A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+ }
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ const int k0 = k00 + k01;
+
+ dmA[n][l][k01/QI8_1] = x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1];
+ }
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- mma_B B;
- float dB[mma_C::ne/2];
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ const int k0 = k00 + k01;
+
+ mma_B B;
+ half2 dsB[mma_C::ne/2];
- B.load(y_qs + j0*MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
- dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
- }
+ dsB[l] = y_dm[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
+ }
#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma_K8(A[n], B);
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n][k01/QI8_1], B);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+ for (int l = 0; l < mma_C::ne; ++l) {
+ const half2 dmA_dsB = dmA[n][l/2][k01/QI8_1]*dsB[l%2];
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+ }
}
}
}
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
- half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+ half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
- const int kbx = threadIdx.x / QI2_K;
const int kqsx = threadIdx.x % QI2_K;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
+ int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
if (need_check) {
i = min(i, i_max);
}
- const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
+ const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
#pragma unroll
for (int l = 0; l < QR2_K; ++l) {
- const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;
+ const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
- int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
- x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
- x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);
-
- if (kqsx % QR2_K != 0) {
- continue;
- }
+ const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
#else
- x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
+ x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
#endif // INT8_MMA_AVAILABLE
}
#endif // FAST_FP16_AVAILABLE
#ifdef INT8_MMA_AVAILABLE
- x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + threadIdx.x] = x_dm_ik;
+ x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
#else
- x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
+ x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
+ const half2 * y_ds = (const half2 *) y;
+ float2 y_df[mmq_x/nwarps];
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = j0 + threadIdx.y;
+ y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
+ }
+
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
- &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
- &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
+ if (k01 < WARP_SIZE/2) {
+ constexpr int ns = 2;
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+ &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+ &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+ } else {
+ constexpr int ns = 1;
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+ &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+ &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+ }
+ }
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K4 mma_A;
+ typedef mma_int_A_I16K8 mma_A_K8;
typedef mma_int_B_J8K4 mma_B;
typedef mma_int_C_I16J8 mma_C;
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
+ const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
const int * y_qs = (const int *) y + 4;
- const float * y_df = (const float *) y;
+ const half2 * y_ds = (const half2 *) y;
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
- mma_A A[ntx][2];
- float dA[ntx][mma_C::ne/2][2];
- float mA[ntx][mma_C::ne/2][2];
+ mma_A A[ntx][8];
+ float dA[ntx][mma_C::ne/2][8];
+ float mA[ntx][mma_C::ne/2][8];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + n*mma_A::I + mma_A::get_i(l);
- const int shift = 2*mma_A::get_k(l);
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ const int k0 = k00 + k01;
- A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 0] >> shift) & 0x03030303;
- A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303;
+ ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
}
+ }
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
#pragma unroll
- for (int kdm = 0; kdm < 2; ++kdm) {
- const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + kdm]);
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
+ const int k0 = k00 + k01;
- dA[n][l][kdm] = dm.x;
- mA[n][l][kdm] = dm.y;
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
+
+ dA[n][l][k01/(QI8_1/2)] = dm.x;
+ mA[n][l][k01/(QI8_1/2)] = dm.y;
}
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- mma_B B[2];
- float dB[mma_C::ne/2];
-
- B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K);
- B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
+ float2 dB[mma_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
- dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
+ dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
}
- mma_C Cm[2];
- mma_A A1;
- A1.x[0] = 0x01010101;
- A1.x[1] = 0x01010101;
- Cm[0].mma_K4(A1, B[0]);
- Cm[1].mma_K4(A1, B[1]);
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ mma_B B[2];
+
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+
+ mma_C Cm[2];
+ if (k01 >= WARP_SIZE * 3/4) {
+ mma_A A1;
+ A1.x[0] = 0x01010101;
+ A1.x[1] = 0x01010101;
+ Cm[0].mma_K4(A1, B[0]);
+ Cm[1].mma_K4(A1, B[1]);
+ }
#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C Cd[2];
+ for (int n = 0; n < ntx; ++n) {
+ mma_C Cd[2];
- Cd[0].mma_K4(A[n][0], B[0]);
- Cd[1].mma_K4(A[n][1], B[1]);
+ Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
+ Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += (
- Cd[0].x[l]*dA[n][l/2][0] + Cd[1].x[l]*dA[n][l/2][1] - Cm[0].x[l]*mA[n][l/2][0] - Cm[1].x[l]*mA[n][l/2][1])*dB[l%2];
+ for (int l = 0; l < mma_C::ne; ++l) {
+ float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
+ if (k01 >= WARP_SIZE * 3/4) {
+ tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
+ }
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+ }
+ }
+ }
+
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
+ float2 sB[mma_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+ }
}
}
}
#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
- int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K);
+ int * x_sc = (int *) (x_df + 1);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
int * x_qs = (int *) x_tile;
int * x_sc = (int *) (x_df + txs.dm);
#endif // INT8_MMA_AVAILABLE
- const int kbx = threadIdx.x / QI3_K;
const int kqsx = threadIdx.x % QI3_K;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
+ int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
if (need_check) {
i = min(i, i_max);
}
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
#pragma unroll
for (int l = 0; l < QR3_K; ++l) {
- const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8;
+ const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
- int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
- x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
-
- if (kqsx % 2 != 0) {
- continue;
- }
+ const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2] = x_qs_k;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
#else
- x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
+ x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
#endif // INT8_MMA_AVAILABLE
}
}
- const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
- const int kbxd = threadIdx.x % blocks_per_tile_x_row;
-
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
- int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
+ int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
#ifdef INT8_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + kbxd] = bxi->d;
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K] = bxi->d;
#else
- x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbxd] = bxi->d;
+ x_df[i] = bxi->d;
#endif // INT8_MMA_AVAILABLE
}
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
- int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
+ int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
if (need_check) {
i = min(i, i_max);
}
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4);
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
- const int ksc = threadIdx.x % (QI3_K/4);
+ const int ksc = threadIdx.x % (WARP_SIZE/8);
const int ksc_low = ksc % (QI3_K/8);
const int shift_low = 4 * (ksc / (QI3_K/8));
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
#ifdef INT8_MMA_AVAILABLE
- x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/4)] = sc;
+ x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/8)] = sc;
#else
- x_sc[i*(WARP_SIZE/4) + i/4 + threadIdx.x % (WARP_SIZE/4)] = sc;
+ x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
#endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
const int * x_qs = (const int *) x;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
-#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
- const int kbx = k0 / QI3_K;
- const int ky = (k0 % QI3_K) * QR3_K;
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
- const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+ const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
- &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
- x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
+ &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
+ x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K4 mma_A;
+ typedef mma_int_A_I16K8 mma_A_K8;
typedef mma_int_B_J8K4 mma_B;
typedef mma_int_C_I16J8 mma_C;
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
- const int * x_sc = (const int *) x_df + WARP_SIZE/QI3_K;
+ const int * x_sc = (const int *) x_df + 1;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
- mma_A A[ntx][2];
- int scA[ntx][mma_C::ne/2][2];
+ mma_A A[ntx][8];
+ int scA[ntx][mma_C::ne/2][8];
float dA[ntx][mma_C::ne/2];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int l = 0; l < mma_A::ne; ++l) {
- const int i = i0 + n*mma_A::I + mma_A::get_i(l);
- const int k = QR3_K*k0 + mma_A::get_k(l);
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+ const int k0 = k00 + k01;
- A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F;
- A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
- A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404);
- A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404);
+ ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- const int kbx = k0 / QI3_K;
- const int ky = (k0 % QI3_K) * QR3_K;
- const int8_t * sc = ((const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q3_K + kbx*4)) + ky/4;
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+ const int k0 = k00 + k01;
- scA[n][l][0] = sc[0];
- scA[n][l][1] = sc[1];
- }
+ const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + k0/16];
+ const int8_t * sc = (const int8_t *) &sc_packed;
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+ for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+ scA[n][l][k01/4 + ksc] = sc[ksc];
+ }
+ }
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/QI3_K];
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K];
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- mma_B B[2];
- float dB[mma_C::ne/2];
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+ mma_B B[2];
+ float dB[mma_C::ne/2];
- B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + 0) % WARP_SIZE, MMQ_TILE_Y_K);
- B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
- dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
- }
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ }
#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C[2];
- C[0].mma_K4(A[n][0], B[0]);
- C[1].mma_K4(A[n][1], B[1]);
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C[2];
+ C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+ C[1].mma_K4(A[n][k01/4 + 1], B[1]);
#pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += (C[0].x[l]*scA[n][l/2][0] + C[1].x[l]*scA[n][l/2][1])*dA[n][l/2]*dB[l%2];
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*
+ (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1]);
+ }
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
const int * x_qs = (const int *) x;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
+ const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
- &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
- x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
+ &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+ x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
- mma_A A[ntx][2];
- int scA[ntx][mma_C::ne/2][2];
- int mA[ntx][mma_C::ne/2][2];
+ mma_A A[ntx][4];
+ int scA[ntx][mma_C::ne/2][4];
+ int mA[ntx][mma_C::ne/2][4];
half2 dmA[ntx][mma_C::ne/2];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 8) {
- A[n][kvdr/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0, MMQ_MMA_TILE_X_K_Q4_K);
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+ const int k0 = k00 + k01;
+
+ A[n][k01/8 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0/QR4_K, MMQ_MMA_TILE_X_K_Q4_K);
#pragma unroll
for (int l = 0; l < mma_A::ne; ++l) {
- A[n][kvdr/4 + 1].x[l] = (A[n][kvdr/4 + 0].x[l] >> 4) & 0x0F0F0F0F;
- A[n][kvdr/4 + 0].x[l] &= 0x0F0F0F0F;
+ A[n][k01/8 + 1].x[l] = (A[n][k01/8 + 0].x[l] >> 4) & 0x0F0F0F0F;
+ A[n][k01/8 + 0].x[l] &= 0x0F0F0F0F;
}
}
#pragma unroll
- for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
-#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
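+ // x_sc stores the Q4_K scales and mins as 8 bit values; read 4 of each as packed ints and unpack them below.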
+ const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 0)];
+ const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 2)];
- const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + k0/16]) + 2 * ((k0 % 16) / 8);
- const uint8_t * m = sc + 8;
+ const uint8_t * sc = (const uint8_t *) &sc_packed;
+ const uint8_t * m = (const uint8_t *) &m_packed;
- scA[n][l][kvdr/4] = sc[kvdr/4];
- mA[n][l][kvdr/4] = m[kvdr/4];
+#pragma unroll
+ for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+ scA[n][l][ksc] = sc[ksc];
+ mA[n][l][ksc] = m[ksc];
}
}
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + k0/QI4_K];
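+ // one 256-value Q4_K superblock per tile row and iteration, so a single dm value per row is sufficient.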
+ dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K];
}
}
float tmpm[ntx][mma_C::ne] = {{0.0f}};
#pragma unroll
- for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
mma_B B;
half2 dsB[mma_C::ne/2];
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C C;
- C.mma_K8(A[n][kvdr/4], B);
+ C.mma_K8(A[n][k01/8], B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
- tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]);
- tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]);
+ tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]);
+ tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]);
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
const int * x_qs = (const int *) x;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
+// #pragma unroll // unrolling this loop causes too much register pressure
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
- &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
- x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
+ &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+ x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K8 mma_A;
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
- mma_A A[ntx][2];
- int scA[ntx][mma_C::ne/2][2];
- int mA[ntx][mma_C::ne/2][2];
+ mma_A A[ntx][4];
+ int scA[ntx][mma_C::ne/2][4];
+ int mA[ntx][mma_C::ne/2][4];
half2 dmA[ntx][mma_C::ne/2];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
- A[n][kvdr/4].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + (QR5_K*k0 + QR5_K*kvdr), MMQ_MMA_TILE_X_K_Q5_K);
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+ const int k0 = k00 + k01;
+
+ A[n][k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + k0, MMQ_MMA_TILE_X_K_Q5_K);
+ }
#pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+ const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 0)];
+ const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 2)];
- const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + k0/16]) + 2 * ((k0 % 16) / 8);
- const uint8_t * m = sc + 8;
+ const uint8_t * sc = (const uint8_t *) &sc_packed;
+ const uint8_t * m = (const uint8_t *) &m_packed;
- scA[n][l][kvdr/4] = sc[kvdr/4];
- mA[n][l][kvdr/4] = m[kvdr/4];
+#pragma unroll
+ for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+ scA[n][l][ksc] = sc[ksc];
+ mA[n][l][ksc] = m[ksc];
}
}
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + k0/QI5_K];
+ dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K];
}
}
float tmpm[ntx][mma_C::ne] = {{0.0f}};
#pragma unroll
- for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+ const int k0 = k00 + k01;
+
mma_B B;
half2 dsB[mma_C::ne/2];
- B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C C;
- C.mma_K8(A[n][kvdr/4], B);
+ C.mma_K8(A[n][k01/8], B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
- tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) * __low2float(dsB[l%2]);
- tmpm[n][l] += mA[n][l/2][kvdr/4] * __high2float(dsB[l%2]);
+ tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]);
+ tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]);
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
const int * x_qs = (const int *) x;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
+// #pragma unroll // unrolling this loop causes too much register pressure
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
- const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
+ const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
- &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
- x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
+ &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+ x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef INT8_MMA_AVAILABLE
typedef mma_int_A_I16K4 mma_A;
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
- mma_A A[ntx][4];
- int scA[ntx][mma_C::ne/2][4];
+ mma_A A[ntx][8];
+ int scA[ntx][mma_C::ne/2][8];
float dA[ntx][mma_C::ne/2];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
- for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
- A[n][kvdr/2 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + 0), MMQ_MMA_TILE_X_K_Q6_K);
- A[n][kvdr/2 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+ const int k0 = k00 + k01;
+
+ A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
+ A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+ }
+
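+ // each int of x_sc packs 4 int8 scales, one scale per 16 values.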
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+ const int k0 = k00 + k01;
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- const int8_t * sc = ((const int8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/8]);
+ const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
+ const int8_t * sc = (const int8_t *) &sc_packed;
- scA[n][l][kvdr/2 + 0] = sc[kvdr/2 + 0];
- scA[n][l][kvdr/2 + 1] = sc[kvdr/2 + 1];
+#pragma unroll
+ for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+ scA[n][l][k01/4 + ksc] = sc[ksc];
+ }
}
}
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K + k0/QI6_K];
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
}
}
float tmp[ntx][mma_C::ne] = {{0.0f}};
#pragma unroll
- for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
mma_B B[2];
float dB[mma_C::ne/2];
- const int k0B = (2*k0 + 2*kvdr) % WARP_SIZE;
- B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k0B, MMQ_TILE_Y_K);
- B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K);
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
- dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C C[2];
- C[0].mma_K4(A[n][kvdr/2 + 0], B[0]);
- C[1].mma_K4(A[n][kvdr/2 + 1], B[1]);
+ C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+ C[1].mma_K4(A[n][k01/4 + 1], B[1]);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
- tmp[n][l] += (C[0].x[l]*scA[n][l/2][kvdr/2 + 0] + C[1].x[l]*scA[n][l/2][kvdr/2 + 1])*dB[l%2];
+ tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
}
}
}
const int2 v = get_int_from_table_16(aux_q4);
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
#ifdef INT8_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = __half2float(bxi->d);
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
#else
x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
#endif // INT8_MMA_AVAILABLE
const int2 v = get_int_from_table_16(aux_q4);
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
#ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
| (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
#ifdef INT8_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + threadIdx.x % 8] = d * (ls - 32);
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
#else
x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
#endif // INT8_MMA_AVAILABLE
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
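+ // the Q5_0 tile data is stored as 8 bit values in shared memory, so the generic Q8_0/Q8_1 dot product kernels can be reused (likewise for Q5_1, IQ4_NL and IQ4_XS below).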
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
-static bool mmq_need_sum(const ggml_type type_x) {
- switch (type_x) {
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- return true;
- case GGML_TYPE_Q5_0:
- return false;
- case GGML_TYPE_Q5_1:
- return true;
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- return false;
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- return true;
- case GGML_TYPE_Q6_K:
- case GGML_TYPE_IQ4_XS:
- case GGML_TYPE_IQ4_NL:
- return false;
- default:
- GGML_ASSERT(false);
- break;
- }
- return false;
-}
-
template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
static __device__ void mul_mat_q_process_tile(
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
- constexpr int qr = ggml_cuda_type_traits<type>::qr;
- constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int mmq_y = get_mmq_y_device();
- constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
extern __shared__ char data_mul_mat_q[];
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
- constexpr int blocks_per_warp = WARP_SIZE / qi;
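+ // one iteration of the kb0 loop now consumes a fixed MMQ_ITER_K values along ne00, independent of the x data type.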
+ constexpr int blocks_per_iter = MMQ_ITER_K / qk;
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
- for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
-
+ for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
-#pragma unroll
- for (int kr = 0; kr < qr; ++kr) {
- const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
+ {
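+ // load the first 128-value half of the Q8_1 data for this iteration; one block_q8_1_mmq covers 4*QK8_1 = 128 values.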
+ const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
tile_y[l] = by0[l];
}
+ }
- __syncthreads();
+ __syncthreads();
-// #pragma unroll // unrolling this loop causes too much register pressure
- for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
- vec_dot(tile_x, tile_y, sum, k0);
- }
+ vec_dot(tile_x, tile_y, sum, 0);
+
+ __syncthreads();
+
+ {
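+ // load the second 128-value half and repeat the dot product at offset WARP_SIZE.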
+ const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+#pragma unroll
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
+ int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
- __syncthreads();
+ tile_y[l] = by0[l];
+ }
}
+
+ __syncthreads();
+
+ vec_dot(tile_x, tile_y, sum, WARP_SIZE);
+
+ __syncthreads();
}
if (fixup) {
}
constexpr int qk = ggml_cuda_type_traits<type>::qk;
- constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int mmq_y = get_mmq_y_device();
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
const int64_t blocks_per_ne00 = ne00 / qk;
- constexpr int blocks_per_warp = WARP_SIZE / qi;
+ constexpr int blocks_per_iter = MMQ_ITER_K / qk;
const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x;
int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
- kbc -= (kbc % blocks_per_ne00) % blocks_per_warp;
- kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;
+ kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
+ kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
// kb0 == k index when doing the matrix multiplication for an output tile.
int kb0_start = kbc % blocks_per_ne00;
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
- constexpr int qi = ggml_cuda_type_traits<type>::qi;
- constexpr int blocks_per_warp = WARP_SIZE / qi;
+ constexpr int blocks_per_iter = MMQ_ITER_K / qk;
const int64_t blocks_per_ne00 = ne00 / qk;
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
bool any_fixup = false;
- const int bidx_start = (blockIdx.y*nty + blockIdx.x) * block_num_mmq / (gridDim.y*gridDim.x);
- const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
+ const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x);
+ const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);
+
+ int64_t kbc_0;
+ int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq;
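+ // the k range of one MMQ block starts where the range of the previous block ends, so only the upper bound needs to be recalculated per iteration.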
for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
- int64_t kbc = (int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq;
- int64_t kbc_stop = (int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
+ kbc_0 = kbc_stop_0;
+ kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
- kbc -= (kbc % blocks_per_ne00) % blocks_per_warp;
- kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;
+ const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter;
+ const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter;
// Skip fixup tile if the MMQ CUDA block never wrote anything to it:
if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {