(char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
output_ne_offset);
+ int64_t antiquantGroupSize = 0;
+ if (src0->ne[0] > QK8_0) {
+ antiquantGroupSize = QK8_0;
+ }
ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
- nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
+ nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor,
&workspaceSize, &executor));
if (workspaceAddr == nullptr) {
workspaceAddr = workspace_allocator.alloc(workspaceSize);
ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
- nullptr, nullptr, nullptr, nullptr, QK8_0,
+ nullptr, nullptr, nullptr, nullptr, antiquantGroupSize,
acl_output_tensor, &workspaceSize, &executor));
ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
workspaceAddr, workspaceSize, executor, ctx.stream()));