// not realy a GGML_TYPE_Q8_0 but same size.
switch (op->op) {
case GGML_OP_MUL_MAT:
- size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
- return true;
+ {
+ size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
+ return true;
+ }
case GGML_OP_MUL_MAT_ID:
- size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
- size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
- size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
- return true;
+ {
+ size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
+ size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
+
+ const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
+ const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
+
+ const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
+
+ size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
+
+ return true;
+ }
default:
// GGML_ABORT("fatal error");
break;
int32_t i2;
};
- GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
- n_as * ne12 * sizeof(mmid_row_mapping)));
+ GGML_ASSERT(params->wsize >=
+ (GGML_PAD(nbw3, sizeof(int64_t)) +
+ n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
+ );
- auto * wdata = (char *) params->wdata;
- auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
- auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+ auto * wdata = (char *)params->wdata;
+ auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));
- struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
+ // total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t)
+ auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
// src1: float32 => param type
for (int64_t i12 = 0; i12 < ne12; ++i12) {