#include "common.cuh"
#include "vecdotq.cuh"
+#include "mma.cuh"
#include <climits>
#include <cstdint>
typedef void (*vec_dot_mmq_t)(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0);
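+// The dp4a and mma paths keep their accumulators in different per-thread
+// register layouts, so the final write to dst is dispatched through a
+// function pointer as well: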
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
struct block_q8_1_mmq {
    half2  ds[4];
    int8_t qs[4*QK8_1];
};
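+// Layout of a y block in the tile: four scale entries (half2 (d, s) pairs, or
+// plain floats for types that don't need the sum) followed by the int8 quants.
+// The scales span 4 ints, which is why the vec_dot functions below offset the
+// tile pointer with `y_qs = (const int *) y + 4`.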
template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
- const float * x_dmf = (const float *) x_dm;
- const int * y_qs = (const int *) y + 4;
- const half2 * y_ds = (const half2 *) y;
+ const float * x_df = (const float *) x_dm;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));

            int u[2*VDR_Q4_0_Q8_1_MMQ];
#pragma unroll
            for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l)         % WARP_SIZE];
                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];
            }

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
-                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
+                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+ GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+ typedef mma_int_A_I16K8 mma_A;
+ typedef mma_int_B_J8K8 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
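+    // A fragments are I=16 x K=8 tiles of ints (each int packs four int8
+    // values), B fragments are J=8 x K=8, and C accumulates the resulting
+    // 16x8 tile of int32 dot products.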
+
+ const float * x_df = (const float *) x_dm;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+ mma_A A;
+ float dA[mma_C::ne/2];
+
+ const int i0 = threadIdx.y*mma_A::I;
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+ for (int l = 0; l < mma_A::ne; ++l) {
+ const int i = i0 + mma_A::get_i(l);
+ const int k = k0 + mma_A::get_k(l) % QI4_0;
+ const int shift = 4*(mma_A::get_k(l) / QI4_0);
+
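+        // q4_0 packs two 4-bit values per byte with an implicit offset of 8;
+        // __vsubss4 subtracts 8 from all four unpacked nibbles at once so the
+        // tensor cores receive properly signed int8 inputs.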
+ A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + mma_C::get_i(2*l);
+
+ dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+ }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+ mma_C C;
+ mma_B B;
+ half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < mma_B::ne; ++l) {
+ const int j = j0 + mma_B::get_j(l);
+ const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+ B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+ }
+
+ C.mma_K8(A, B);
+
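+        // C now holds raw int32 dot products; scale each entry by the q4_0
+        // block scale dA and the q8_1 scale (the low half of ds). The -8
+        // offset was already folded into A above.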
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
+ }
+ }
+}
+
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
}
template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
}
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+ GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+ typedef mma_int_A_I16K8 mma_A;
+ typedef mma_int_B_J8K8 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+ mma_A A;
+ half2 dmA[mma_C::ne/2];
+
+ const int i0 = threadIdx.y*mma_A::I;
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+ for (int l = 0; l < mma_A::ne; ++l) {
+ const int i = i0 + mma_A::get_i(l);
+ const int k = k0 + mma_A::get_k(l) % QI4_0;
+ const int shift = 4*(mma_A::get_k(l) / QI4_0);
+
+ A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + mma_C::get_i(2*l);
+
+ dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+ }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+ mma_C C;
+ mma_B B;
+ half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < mma_B::ne; ++l) {
+ const int j = j0 + mma_B::get_j(l);
+ const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+ B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+ }
+
+ C.mma_K8(A, B);
+
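+        // dmA packs (d, m) for the q4_1 block and dsB packs (d, s) for the
+        // q8_1 block, so their half2 product yields (d*d, m*s): scale and
+        // offset correction in a single step.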
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
+ sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+ }
+ }
+}
+
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
}
template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
}
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+ GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+ typedef mma_int_A_I16K8 mma_A;
+ typedef mma_int_B_J8K8 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ const float * x_df = (const float *) x_dm;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ mma_A A;
+ float dA[mma_C::ne/2];
+
+ const int i0 = threadIdx.y*mma_A::I;
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
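+    // load_tiles_q5_0 stores the 5-bit values already expanded to 8 bits, two
+    // ints per original int, hence the doubled row stride (2*WARP_SIZE + 1)
+    // and the interleaved k index below.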
+#pragma unroll
+ for (int l = 0; l < mma_A::ne; ++l) {
+ const int i = i0 + mma_A::get_i(l);
+ const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
+
+ A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + mma_C::get_i(2*l);
+
+ dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
+ }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+ mma_C C;
+ mma_B B;
+ float dB[mma_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < mma_B::ne; ++l) {
+ const int j = j0 + mma_B::get_j(l);
+ const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+ B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+ }
+
+ C.mma_K8(A, B);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
+ }
+ }
+}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
}
template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
}
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+ GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+ typedef mma_int_A_I16K8 mma_A;
+ typedef mma_int_B_J8K8 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+ mma_A A;
+ half2 dmA[mma_C::ne/2];
+
+ const int i0 = threadIdx.y*mma_A::I;
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+ for (int l = 0; l < mma_A::ne; ++l) {
+ const int i = i0 + mma_A::get_i(l);
+ const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
+
+ A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + mma_C::get_i(2*l);
+
+ dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
+ }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+ mma_C C;
+ mma_B B;
+ half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < mma_B::ne; ++l) {
+ const int j = j0 + mma_B::get_j(l);
+ const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+ B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+ }
+
+ C.mma_K8(A, B);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
+ sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+ }
+ }
+}
+
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
}
template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
}
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+ GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+ typedef mma_int_A_I16K8 mma_A;
+ typedef mma_int_B_J8K8 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ const float * x_df = (const float *) x_dm;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ mma_A A;
+ float dA[mma_C::ne/2];
+
+ const int i0 = threadIdx.y*mma_A::I;
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
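+    // q8_0 already stores signed 8-bit values, so the A fragment loads
+    // directly, with no unpacking or offset correction.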
+#pragma unroll
+ for (int l = 0; l < mma_A::ne; ++l) {
+ const int i = i0 + mma_A::get_i(l);
+ const int k = k0 + mma_A::get_k(l);
+
+ A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + mma_C::get_i(2*l);
+
+ dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
+ }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+ mma_C C;
+ mma_B B;
+ float dB[mma_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < mma_B::ne; ++l) {
+ const int j = j0 + mma_B::get_j(l);
+ const int k = k0 + mma_B::get_k(l);
+
+ B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+ }
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
+ }
+
+ C.mma_K8(A, B);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
+ }
+ }
+}
+
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
}
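+// Write-back for the dp4a path: each thread stores the elements it
+// accumulated, one per (j0, i0) tile step, exactly as mul_mat_q previously
+// did inline at the end of the kernel.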
+template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+
+ if (j >= ne1) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+
+ if (need_check && i >= ne0) {
+ continue;
+ }
+
+ dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+ }
+ }
+}
+
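+// Write-back matching the tensor core accumulator layout: row/column indices
+// come from mma_C::get_i/get_j because each lane holds a strided slice of the
+// 16x8 C tile rather than a contiguous column.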
+template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+ typedef mma_int_C_I16J8 mma_C;
+
+ const int i0 = threadIdx.y*mma_C::I;
+ static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
+
+ if (j >= ne1) {
+ continue;
+ }
+
+ const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
+
+ if (need_check && i >= ne0) {
+ continue;
+ }
+
+ dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+ }
+ }
+}
+
// -------------------------------------------------------------------------------------------------------------------------------------
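+// For each type the traits select the tensor core kernels when int8 tensor
+// cores are available (INT8_MMA_AVAILABLE) and fall back to dp4a otherwise.
+// The k-quants further down have no mma implementation and always use the
+// dp4a write-back.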
template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
struct mmq_type_traits;

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};
static bool mmq_need_sum(const ggml_type type_x) {
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
+ constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back;
constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
- float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
+ float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
}
}
-#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
- const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
-
- if (j >= ne1) {
- return;
- }
-
-#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
- const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
-
- if (need_check && i >= ne0) {
- continue;
- }
-
- dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
- }
- }
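+    // Write-back is now dispatched through the traits-selected function so
+    // that it matches the register layout of the chosen vec_dot path.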
+ write_back(sum, dst, ne0, ne1);
}
struct mmq_args {
launch_mul_mat_q<type, 8, 4>(args, stream);
break;
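+        // The reduced warp counts below keep nwarps*mma_A::I consistent with
+        // mmq_y, as required by the static_asserts in the mma kernels.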
case 16:
- launch_mul_mat_q<type, 16, 8>(args, stream);
+ launch_mul_mat_q<type, 16, 4>(args, stream);
break;
case 24:
- launch_mul_mat_q<type, 24, 8>(args, stream);
+ launch_mul_mat_q<type, 24, 4>(args, stream);
break;
case 32:
launch_mul_mat_q<type, 32, 8>(args, stream);