* kleidiai: add data type check to get_tensor_traits
* Added a check for the F16 data type to the get_tensor_traits path taken when
the input data is not in ggml_backend_cpu_kleidiai_buffer_type format (this path is unsupported for Q4/Q8)
Signed-off-by: Martin Klacer <redacted>
Change-Id: I9aca4b9b8d669d35db6f1dbcc4e080b1919b1de7
* Updated ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
Updated the kleidiai.cpp file as per the suggestion
Co-authored-by: Georgi Gerganov <redacted>
---------
Signed-off-by: Martin Klacer <redacted>
Co-authored-by: Georgi Gerganov <redacted>
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
} else {
+ if (op->src[0]->type != GGML_TYPE_F16) {
+ return nullptr;
+ }
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
- const bool has_kernel = slot_total > 0;
- if (has_kernel && op->src[1]->ne[1] > 1) {
+ if (slot_total > 0 && op->src[1]->ne[1] > 1) {
if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
return nullptr;