}
/**
 * @brief Performs quantized matrix multiplication for Mixture of Experts (MoE)
 * models using the CANN backend.
 *
 * This function implements MUL_MAT_ID operation for quantized weight matrices
 * (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on
 * the provided expert indices, and computes matrix multiplication using CANN's
 * WeightQuantBatchMatmulV2 operator.
 *
 * The function performs the following steps:
 * 1. Converts input/output tensors to F16 format if necessary
 * 2. Uses IndexSelect to extract expert-specific weights and scales based on indices
 * 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2
 * 4. Converts output back to the target type if needed
 *
 * Tensor shapes:
 * - dst:  [M, K, N, 1] - output tensor
 * - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0)
 * - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast)
 * - ids:  [K, N]       - expert indices for routing
 *
 * @param ctx The CANN backend context for operation execution.
 * @param dst The destination tensor where the multiplication result will be stored.
 *
 * @note Only Q4_0 and Q8_0 quantization formats are supported.
 * @note The function handles automatic type conversion to/from F16 as needed by the hardware.
 */
static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    // dst:  [M, K, N, 1]
    // src0: [D, M, A, 1] - quantized weights
    // src1: [D, B, N, 1] - input activations, B = K or B = 1
    // ids:  [K, N]       - expert indices
    ggml_tensor * src0 = dst->src[0];
    ggml_tensor * src1 = dst->src[1];
    ggml_tensor * ids  = dst->src[2];

    GGML_ASSERT(src0->ne[3] == 1);
    GGML_ASSERT(src1->ne[3] == 1);
    GGML_ASSERT(dst->ne[3] == 1);
    GGML_ASSERT(src1->ne[2] == ids->ne[1]);

    const int64_t        n_batches        = ids->ne[1];
    const int64_t        n_select_experts = ids->ne[0];
    const enum ggml_type type             = src0->type;

    const int32_t group_size = QK8_0;  // Both Q4_0 and Q8_0 use group size of 32
    GGML_ASSERT(group_size == QK4_0);

    // Per-element weight size in bytes (Q4_0 packs two 4-bit values per byte),
    // kept as float so it can express the half-byte Q4_0 stride below.
    const float weight_elem_size =
        (type == GGML_TYPE_Q4_0) ? 0.5f :
        (type == GGML_TYPE_Q8_0) ? 1.0f :
        (GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f);

    // Bytes per weight row, computed in integer arithmetic so the scale-region
    // offset and buffer sizes below stay exact even when the products exceed
    // what float's 24-bit mantissa can represent.
    const size_t weight_row_bytes = (type == GGML_TYPE_Q4_0) ? (size_t) src0->ne[0] / 2 : (size_t) src0->ne[0];

    // The per-group F16 scales are stored contiguously right after the packed weights.
    const size_t weight_size     = weight_row_bytes * src0->ne[1] * src0->ne[2];
    const size_t scale_elem_size = sizeof(uint16_t);
    char *       scale_data      = (char *) src0->data + weight_size;

    // Allocate buffers for the expert weights and scales gathered by IndexSelect
    const size_t         selected_weight_size = weight_row_bytes * src0->ne[1] * n_select_experts;
    ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size);
    void *               selected_weight_buffer = selected_weight_alloc.get();

    const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size;
    ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size);
    void *               selected_scale_buffer = selected_scale_alloc.get();

    // Helper: return the tensor's own data when it is already F16; otherwise
    // allocate an F16 staging buffer and (if need_cast) cast the tensor into it.
    // NOTE(review): the offsets computed later assume the F16 data is contiguous;
    // this holds for the staging buffer — confirm for pre-existing F16 tensors.
    constexpr size_t f16_elem_size = sizeof(uint16_t);
    auto prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator,
                                  bool need_cast = false) -> void * {
        if (tensor->type == GGML_TYPE_F16) {
            return tensor->data;
        }

        size_t total_size = f16_elem_size;
        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            total_size *= tensor->ne[i];
        }
        void * buffer = allocator.alloc(total_size);

        if (need_cast == false) {
            return buffer;
        }

        int64_t ne[GGML_MAX_DIMS];
        size_t  nb[GGML_MAX_DIMS] = { f16_elem_size };
        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            ne[i] = tensor->ne[i];
            if (i > 0) {
                nb[i] = nb[i - 1] * ne[i - 1];
            }
        }

        acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor);
        acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
        aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16);

        return buffer;
    };

    // Prepare input (cast to F16 now) and output (F16 staging, cast back at the end)
    ggml_cann_pool_alloc input_alloc(ctx.pool());
    void *               input_buffer = prepare_f16_buffer(src1, input_alloc, true);

    ggml_cann_pool_alloc output_alloc(ctx.pool());
    void *               output_buffer = prepare_f16_buffer(dst, output_alloc, false);

    // Process each batch
    for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
        // Create index tensor for current batch
        const size_t   index_offset  = batch_idx * ids->nb[1];
        acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset);

        // Select quantized weights using expert indices
        // Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte
        const int64_t weight_d         = (int64_t) weight_row_bytes;
        const int64_t weight_m         = src0->ne[1];
        const int64_t weight_n_experts = src0->ne[2];

        int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts };
        size_t  weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) };

        acl_tensor_ptr all_weights =
            ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3);

        int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts };
        size_t  selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t),
                                          weight_d * weight_m * sizeof(int8_t) };

        acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t),
                                                                  selected_weight_ne, selected_weight_nb, 3);

        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get());

        // Select scales using the same expert indices
        const int64_t scale_d     = src0->ne[0] / group_size;
        int64_t       scale_ne[3] = { scale_d, weight_m, weight_n_experts };
        size_t        scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size };

        acl_tensor_ptr all_scales =
            ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3);

        int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts };
        size_t  selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size,
                                         scale_d * weight_m * scale_elem_size };

        acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size,
                                                                 selected_scale_ne, selected_scale_nb, 3);

        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get());

        // Process each expert for current batch
        // IndexSelect output layout: [D, M, K] in contiguous format
        // WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride
        for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) {
            // Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input
            const size_t input_offset =
                (batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size;
            const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size;

            // Create weight view for current expert: [D, M, K] -> [M, D].
            // The strides are float because a Q4_0 element occupies half a byte.
            int64_t      weight_view_ne[2]  = { weight_m, src0->ne[0] };
            float        weight_view_nb[2]  = { src0->ne[0] * weight_elem_size, weight_elem_size };
            const size_t weight_view_offset = expert_idx * selected_weight_nb[2];

            acl_tensor_ptr weight_view =
                ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size,
                                        weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset);

            // Create scale view for current expert: [D, M, K] -> [M, D]
            int64_t      scale_view_ne[2]  = { weight_m, scale_d };
            size_t       scale_view_nb[2]  = { selected_scale_nb[1], selected_scale_nb[0] };
            const size_t scale_view_offset = expert_idx * selected_scale_nb[2];

            acl_tensor_ptr scale_view =
                ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne,
                                        scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset);

            // Create input activation tensor [D, 1]
            int64_t input_ne[2] = { src1->ne[0], 1 };
            size_t  input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size };

            acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne,
                                                                  input_nb, 2, ACL_FORMAT_ND, input_offset);

            // Create output tensor [M, 1]
            int64_t output_ne[2] = { dst->ne[0], 1 };
            size_t  output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size };

            acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne,
                                                                   output_nb, 2, ACL_FORMAT_ND, output_offset);

            // Perform quantized matrix multiplication
            GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(),
                                    scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size,
                                    output_tensor.get());
        }
    }

    // Cast output back to original type if we used a temporary F16 buffer
    if (dst->type != GGML_TYPE_F16) {
        int64_t ne[GGML_MAX_DIMS];
        size_t  nb[GGML_MAX_DIMS] = { f16_elem_size };
        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            ne[i] = dst->ne[i];
            if (i > 0) {
                nb[i] = nb[i - 1] * ne[i - 1];
            }
        }

        acl_tensor_ptr f16_output =
            ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
        acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst);

        aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type));
    }
}
void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {