return tensor->ne[dim];
}
+template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
+constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
+ using V = std::remove_reference_t<Variant>;
+ return (std::is_invocable_r_v<
+ Ret,
+ std::variant_alternative_t<Is, V>,
+ Args...> || ...);
+}
+
+template <typename Variant, typename Ret, typename... Args>
+constexpr bool variant_any_invocable_v =
+ variant_any_invocable_impl<Variant, Ret, Args...>(
+ std::make_index_sequence<
+ std::variant_size_v<std::remove_reference_t<Variant>>>{});
+
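+// variant_call<Ret>(var, args...) invokes whichever callable `var` currently holds.
+// If no alternative of the variant could ever satisfy (Ret, Args...), compilation fails
+// via the static_assert below; if the currently held alternative does not match, the
+// call aborts at runtime. Illustrative usage (hypothetical variant type and function):
+//   std::variant<size_t (*)(size_t, size_t), void (*)(void *)> fn = &some_packed_size;
+//   size_t sz = variant_call<size_t>(fn, 64, 128);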
template<typename Ret, typename Variant, typename... Args>
-static Ret variant_call(const Variant & var, Args&&... args) {
- return std::visit([&](auto&& func) -> Ret {
- if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
- return func(std::forward<Args>(args)...);
- } else {
- throw std::runtime_error("Invalid function type in variant_call");
- }
- }, var);
+static inline Ret variant_call(Variant && var, Args&&... args) {
+ static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
+ "No alternative in Variant is invocable with the provided arguments and return type.");
+
+ return std::visit(
+ [&](auto && f) -> Ret {
+ using F = std::decay_t<decltype(f)>;
+ if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
+ return std::forward<decltype(f)>(f)(std::forward<Args>(args)...);
+ } else {
+ GGML_ABORT("Invalid function type in variant_call");
+ }
+ },
+ std::forward<Variant>(var)
+ );
}
namespace ggml::cpu::kleidiai {
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) {
- size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
+ const int64_t lhs_batch_size0 = op->src[1]->ne[2];
+ const int64_t rhs_batch_size0 = op->src[0]->ne[2];
+ const int64_t r = lhs_batch_size0 / rhs_batch_size0;
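+ // r is the broadcast ratio: how many LHS (activation) batches share one RHS (weight) batch.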
+ size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
k * n * sizeof(float) + n * sizeof(float);
} else {
return true;
}
-
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
if (dst->op == GGML_OP_MUL_MAT) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
}
bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
- static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
-
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels);
- bool is_gemv = src1->ne[1] == 1;
+ const bool is_gemv = src1->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
GGML_ASSERT(kernel);
const int64_t lhs_batch_size0 = ne12;
const int64_t rhs_batch_size0 = ne02;
- const int64_t batch_size = rhs_batch_size0;
+ const int64_t batch_size = lhs_batch_size0;
+ GGML_ASSERT(rhs_batch_size0 > 0);
+ GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
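+ // Broadcast ratio: each RHS (weight) batch is reused by r consecutive LHS (activation) batches.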
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
- const int64_t m = ne11 * r;
- const int64_t n = ne01;
- const int64_t k = ne00;
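+ // One batch iteration packs a single KV group of ne11 activation rows; n and k are
+ // src0's row count (ne01) and the shared reduction dim (ne00).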
+ const int64_t m_group = ne11;
+ const int64_t m = m_group;
+ const int64_t n = ne01;
+ const int64_t k = ne00;
const size_t lhs_stride = src1->nb[1];
const size_t rhs_stride = src0->nb[1];
const size_t dst_stride = dst->nb[1];
- const int64_t mr = static_cast<int64_t>(kernel->get_mr());
- const int64_t nr = static_cast<int64_t>(kernel->get_nr());
- const int64_t kr = static_cast<int64_t>(kernel->get_kr());
- const int64_t sr = static_cast<int64_t>(kernel->get_sr());
+ const int64_t mr = (int64_t) kernel->get_mr();
+ const int64_t nr = (int64_t) kernel->get_nr();
+ const int64_t kr = (int64_t) kernel->get_kr();
+ const int64_t sr = (int64_t) kernel->get_sr();
- const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
- const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
- const size_t kxn_size = k * n * sizeof(float);
- const size_t bias_size = n * sizeof(float);
+ const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+ const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
+ const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float);
+ const size_t bias_size = (size_t)n * sizeof(float);
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
GGML_ASSERT(wsize_required <= params->wsize);
uint8_t * bias = rhs_kxn + kxn_size;
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
- const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
- const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
- uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
+ const int64_t rhs_batch_idx = batch_idx / r;
+ const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
+ const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
+ uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
- // LHS packing
+ // LHS packing (threaded over m, honoring mr alignment and KV groups)
{
const int64_t m_roundup_mr = kai_roundup(m, mr);
const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
if (ith < num_threads) {
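+ // Each participating thread packs an mr-aligned slice of rows; the last thread takes the remainder.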
- const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
+ const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
- const int64_t m_start = ith * num_m_per_thread0;
- const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+ const int64_t m_start = ith * num_m_per_thread0;
+ const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+ // Base packed offset (aligned) and per-row stride in bytes
+ const size_t base_packed_off = variant_call<size_t>(
+ lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+ const size_t next_block_off = variant_call<size_t>(
+ lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+ const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
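+ // Note: this assumes get_packed_offset grows by a constant byte count per source row,
+ // so rows inside an mr block can be addressed individually via row_stride_bytes.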
+
+ int64_t remaining = m_count;
+ int64_t cur = m_start;
+
+ while (remaining > 0) {
+ const int64_t row_in_group = cur;
+ const int64_t avail = m_group - row_in_group;
+ const int64_t take = KAI_MIN(avail, remaining);
- const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
- const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
+ const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
+ const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
+ void * dst_ptr = lhs_packed + dst_off;
- const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
- void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
+ variant_call<void>(lhs_info->pack_func,
+ (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
+ /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
- variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+ cur += take;
+ remaining -= take;
+ }
}
}
- // RHS packing
- if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
- // First thread to reach this point handles RHS packing
- memset(bias, 0, n * sizeof(float));
- transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
- reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
-
- variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
- rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
+ // RHS packing (single thread), then synchronize
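+ // With r > 1 the same rhs_batch_base is packed again for each LHS batch that maps to it.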
+ if (ith == 0) {
+ memset(bias, 0, (size_t)n * sizeof(float));
+ transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
+ reinterpret_cast<float *>(rhs_kxn),
+ reinterpret_cast<const uint16_t *>(rhs_batch_base),
+ rhs_stride);
+
+ variant_call<void>(kernels->rhs_info.pack_func,
+ /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
+ /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
+ rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
}
ggml_barrier(params->threadpool);
- first_to_arrive.clear(std::memory_order_release);
-
- // Perform the matmul
+ // Matmul (threaded over n)
{
- const int64_t m_to_process = m;
- const int64_t m_start = 0;
-
- const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
- int64_t num_threads = KAI_MIN(n / n_step, nth);
- if (num_threads <= 0) {
- num_threads = 1;
+ const int64_t n_step = (int64_t) kernel->get_n_step();
+ int64_t num_threads_n = KAI_MIN(n / n_step, nth);
+ if (num_threads_n <= 0) {
+ num_threads_n = 1;
}
- if (ith < num_threads) {
- const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
- const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
+ if (ith < num_threads_n) {
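+ // Mirror the LHS split: n_step-aligned column slices per thread, remainder to the last thread.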
+ const int64_t num_n_per_thread0 = round_down((size_t)(n / num_threads_n), (size_t)n_step);
+ const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;
const int64_t n_start = ith * num_n_per_thread0;
- const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
+ const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
- const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
- const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
- const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
+ // LHS packed base at row 0 (consistent with packing above)
+ const size_t lhs_packed_offset0 = variant_call<size_t>(
+ lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+ const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
+ const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
- const void * lhs_ptr = lhs_packed + lhs_packed_offset;
+ const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
- float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
+ float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
- variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+ variant_call<void>(kernel->run_kernel,
+ (size_t)m, (size_t)n_to_process, (size_t)k,
+ lhs_ptr, rhs_ptr,
+ dst_ptr, dst_stride, sizeof(float),
+ -FLT_MAX, FLT_MAX);
}
}
if (batch_idx != batch_size - 1) {
- // This barrier is necessary when the batch size is larger than 1. While processing a batch,
- // the work data buffer (params->wdata) is used as temporary storage which means that only
- // a single batch can be processed at any given time. No barrier is needed for the last
- // batch since GGML inserts a barrier between the execution of every operator.
+ // params->wdata is reused as scratch for every batch, so all threads must finish this
+ // batch before the next one starts; the last batch needs no barrier because GGML
+ // inserts one between operators.
ggml_barrier(params->threadpool);
}
}