static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+static void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
#endif
}
-static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+static void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
#endif
}
-static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+static void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK_K == 256);
assert(k % QK_K == 0);
const int nb = k / QK_K;
// i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
for (int j = 0; j < QK_K * 4; j++) {
int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
- int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
src_offset += (j % blck_size_interleave);
int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
#endif
}
-static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
+void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
+
+template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
assert(nrow == 4);
UNUSED(nrow);
- if (blck_size_interleave == 4) {
- quantize_q8_0_4x4(x, vy, n_per_row);
- } else if (blck_size_interleave == 8) {
- quantize_q8_0_4x8(x, vy, n_per_row);
- } else {
- assert(false);
- }
+ ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
}
-static void quantize_mat_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
assert(nrow == 4);
UNUSED(nrow);
- if (blck_size_interleave == 8) {
- quantize_q8_K_4x8(x, vy, n_per_row);
- } else {
- assert(false);
- }
+ ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
+}
+
+template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+ assert(nrow == 4);
+ UNUSED(nrow);
+ ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
}
static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
//}
// gemv
-template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemv(int, float *, size_t, const void *, const void *, int, int);
-template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
-template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
-template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
-template <> void gemv<block_q4_K, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
-template <>
-void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
// gemm
-template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemm(int, float *, size_t, const void *, const void *, int, int);
-template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
-template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
-template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
-template <> void gemm<block_q4_K, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
-template <>
-void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
// not realy a GGML_TYPE_Q8_0 but same size.
switch (op->op) {
- case GGML_OP_MUL_MAT:
- size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
- return true;
- case GGML_OP_MUL_MAT_ID:
- size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
- size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
- size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
- return true;
- default:
- // GGML_ABORT("fatal error");
- break;
+ case GGML_OP_MUL_MAT:
+ size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
+ return true;
+ case GGML_OP_MUL_MAT_ID:
+ size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
+ size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
+ size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
+ return true;
+ default:
+ // GGML_ABORT("fatal error");
+ break;
}
return false;
}
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
switch (op->op) {
- case GGML_OP_MUL_MAT:
- forward_mul_mat(params, op);
- return true;
- case GGML_OP_MUL_MAT_ID:
- forward_mul_mat_id(params, op);
- return true;
- default:
- // GGML_ABORT("fatal error");
- break;
+ case GGML_OP_MUL_MAT:
+ forward_mul_mat(params, op);
+ return true;
+ case GGML_OP_MUL_MAT_ID:
+ forward_mul_mat_id(params, op);
+ return true;
+ default:
+ // GGML_ABORT("fatal error");
+ break;
}
return false;
}
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
int64_t i11_processed = 0;
- if(PARAM_TYPE == GGML_TYPE_Q8_K) {
- for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
- quantize_mat_q8_K((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
- INTER_SIZE);
- }
- } else {
- for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
- quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
- INTER_SIZE);
- }
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+ ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
}
+
i11_processed = ne11 - ne11 % 4;
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
int64_t src0_start = (ith * ne01) / nth;
int64_t src0_end = ((ith + 1) * ne01) / nth;
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
- src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
if (src0_start >= src0_end) {
return;
}
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (ne11 > 3) {
- gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
- (const char *) src0->data + src0_start * nb01,
- (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+ (float *) ((char *) dst->data) + src0_start, ne01,
+ (const char *) src0->data + src0_start * nb01,
+ (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
}
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
- (const char *) src0->data + src0_start * nb01,
- (const char *) src1_wdata + (src1_col_stride * iter), 1,
- src0_end - src0_start);
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+ (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+ (const char *) src0->data + src0_start * nb01,
+ (const char *) src1_wdata + (src1_col_stride * iter), 1,
+ src0_end - src0_start);
}
}
const int ith = params->ith;
const int nth = params->nth;
- const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
+ const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
const int n_ids = ids->ne[0]; // n_expert_used
const int n_as = ne02; // n_expert
- const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+ const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
const size_t nbw2 = nbw1*ne11;
const size_t nbw3 = nbw2*ne12;
GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
n_as * ne12 * sizeof(mmid_row_mapping)));
- auto wdata = (char *) params->wdata;
- auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
- int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+ auto * wdata = (char *) params->wdata;
+ auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
+ auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
- // src1: float32 => block_q8_0
+ // src1: float32 => param type
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
continue;
}
- auto src0_cur = (const char *) src0->data + cur_a*nb02;
+ const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
//const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1; // src1 rows
int64_t src0_cur_start = (ith * ne01) / nth;
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
- src0_cur_start =
- (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
- src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
- if (src0_cur_start >= src0_cur_end) return;
+ src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
+ src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+
+ if (src0_cur_start >= src0_cur_end) {
+ return;
+ }
for (int ir1 = 0; ir1 < nr1; ir1++) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
- const int id = row_mapping.i1; // selected expert index
- const int64_t i11 = id % ne11;
- const int64_t i12 = row_mapping.i2; // row index in src1
+ const int id = row_mapping.i1; // selected expert index
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = row_mapping.i2; // row index in src1
- const int64_t i1 = id; // selected expert index
- const int64_t i2 = i12; // row
+ const int64_t i1 = id; // selected expert index
+ const int64_t i2 = i12; // row
- auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+ const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
- ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
- ne01, src0_cur + src0_cur_start * nb01,
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+ (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
+ src0_cur + src0_cur_start * nb01,
src1_col, 1, src0_cur_end - src0_cur_start);
}
}
static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
// instance for IQ4
-static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_IQ4_NL> iq4_nl_4x4_q8_0;
+static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
} // namespace ggml::cpu::aarch64