]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Add experimental ggml-hexagon backend for the Hexagon NPU (llama/16547)
authorMax Krasnyansky <redacted>
Wed, 22 Oct 2025 20:47:09 +0000 (13:47 -0700)
committerGeorgi Gerganov <redacted>
Sat, 1 Nov 2025 07:41:35 +0000 (09:41 +0200)
* model: add support for extra bufs for all devices

* hexagon: add experimental ggml-hexagon backend for the Hexagon NPU

This commit introduces a new experimental backend `ggml-hexagon` with support for the Hexagon NPU.

Highlights:
- Supports Hexagon versions: v73, v75, v79, and v81
- Targets Android devices based on Snapdragon SoCs: Gen3, 8-Elite, and 8-Elite Gen5
- Supports Q4_0, Q8_0, MXFP4, and FP32 data types
- Implements core LLM ops: MUL_MAT/MUL_MAT_ID, ADD/SUB/MUL/ADD_ID, RMS_NORM, ROPE, GLU/SWIGLU, SOFTMAX

**Note:** This backend is experimental and may exhibit instability or limited performance across supported devices.
It is intended for early testing and feedback from llama.cpp/ggml developer and user community.

Co-Authored-By: Rajdeep Ganguly <redacted>
Co-Authored-By: Todor Boinovski <redacted>
* hexagon: fix format checker errors

* hexagon: update readme and cmake presets

* ci: add android-ndk-build jobs that build plain ARM64 and Snapdragon versions

* hexagon: add simple graph optimizer for stacking MUL_MAT ops with the same input

* hexagon: move ADB helper scripts into scripts/snapdragon/adb

* hexagon: replace all f/printfs with GGML_LOG_...

* readme: add hexagon to the list supported backends

* hexagon: stack malmuts with quantized inputs only

* hexagon: add TODO for fixing issues in hexagon_graph_optimize

* hexagon: update to hex-sdk 6.4.0 and add scripts for running on QDC

* scripts: fix lint errors

* scripts: update qdc pytest script to make linter happy

* hexagon: add reduce sum in fp32

* hexagon: reduce number of vector stores in matmul output

* hexagon: remove the need for vdelta in reduce-multiply-x8

* hexagon: consistent use of reduce_sum_fp32 for row_sums

* hexagon: some more matmul optimizations and comments

Optimize cases where tensor dims are not multiple of 1024 (e.g in Qwen models).
We've handled those cases already but at a higher overhead.

* hexagon: update cmake presets

* hexagon: add OPMASK support for run-bench.sh wrapper

* hexagon: update to use GGML_BACKEND_API

* hexagon: remove unused logic for setting tensor flags for the views

* hexagon: add asserts to set/get_tensor to make sure we handle complete tensors

Same asserts as the CPU backend.

* hexagon: use cpy_tensor slow path for non-host buffers

* hexagon: error checks in the buffer allocator

* cmake: move include(extProj) under ggml-hexagon

* hexagon: don't forget to delete the backend on free

* hexagon: set/get_tensor size assert apply only to quantized tensors

* hexagon: reintroduce HEX_VERBOSE wrapper for GGML_LOG_DEBUG for now

GGML_LOG_DEBUG is always enabled for test-backend-ops and the output gets in the way.
Ideally we need a bit more finer log levels.

* docs: typos in hexagon developer docs (libggm-...)

* hexagon: overhaul error handling in the session/device allocation

this should handle all failure paths in the session allocation.

* hexagon: update cmake presets to enable fp16 vectors

* hexagon: remove unused time_usec function

* hexagon: don't forget to release buffer contexts

* hexagon: fixed indents in hvx-utils (missed clang-format auto-format failure)

* hexagon: remove custom can_repeat function and use ggml_can_repeat

---------

Co-authored-by: Rajdeep Ganguly <redacted>
Co-authored-by: Todor Boinovski <redacted>
31 files changed:
CMakeLists.txt
include/ggml-hexagon.h [new file with mode: 0644]
src/CMakeLists.txt
src/ggml-backend-reg.cpp
src/ggml-hexagon/CMakeLists.txt [new file with mode: 0644]
src/ggml-hexagon/ggml-hexagon.cpp [new file with mode: 0644]
src/ggml-hexagon/htp-utils.c [new file with mode: 0644]
src/ggml-hexagon/htp-utils.h [new file with mode: 0644]
src/ggml-hexagon/htp/CMakeLists.txt [new file with mode: 0644]
src/ggml-hexagon/htp/act-ops.c [new file with mode: 0644]
src/ggml-hexagon/htp/binary-ops.c [new file with mode: 0644]
src/ggml-hexagon/htp/cmake-toolchain.cmake [new file with mode: 0644]
src/ggml-hexagon/htp/htp-ctx.h [new file with mode: 0644]
src/ggml-hexagon/htp/htp-dma.c [new file with mode: 0644]
src/ggml-hexagon/htp/htp-dma.h [new file with mode: 0644]
src/ggml-hexagon/htp/htp-msg.h [new file with mode: 0644]
src/ggml-hexagon/htp/htp-ops.h [new file with mode: 0644]
src/ggml-hexagon/htp/htp_iface.idl [new file with mode: 0644]
src/ggml-hexagon/htp/hvx-exp.c [new file with mode: 0644]
src/ggml-hexagon/htp/hvx-inverse.c [new file with mode: 0644]
src/ggml-hexagon/htp/hvx-sigmoid.c [new file with mode: 0644]
src/ggml-hexagon/htp/hvx-utils.c [new file with mode: 0644]
src/ggml-hexagon/htp/hvx-utils.h [new file with mode: 0644]
src/ggml-hexagon/htp/main.c [new file with mode: 0644]
src/ggml-hexagon/htp/matmul-ops.c [new file with mode: 0644]
src/ggml-hexagon/htp/ops-utils.h [new file with mode: 0644]
src/ggml-hexagon/htp/rope-ops.c [new file with mode: 0644]
src/ggml-hexagon/htp/softmax-ops.c [new file with mode: 0644]
src/ggml-hexagon/htp/unary-ops.c [new file with mode: 0644]
src/ggml-hexagon/htp/worker-pool.c [new file with mode: 0644]
src/ggml-hexagon/htp/worker-pool.h [new file with mode: 0644]

index 73032be68e153e73d412acee2d43e645ced5bbbd..181f179ed171c924ae75ad80cdccd444d351a5cf 100644 (file)
@@ -251,6 +251,8 @@ option(GGML_OPENCL_USE_ADRENO_KERNELS       "ggml: use optimized kernels for Adr
 set   (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
                                             "gmml: OpenCL API version to target")
 
+option(GGML_HEXAGON                         "ggml: enable Hexagon backend"                    OFF)
+
 # toolchain for vulkan-shaders-gen
 set   (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
 
diff --git a/include/ggml-hexagon.h b/include/ggml-hexagon.h
new file mode 100644 (file)
index 0000000..6e07900
--- /dev/null
@@ -0,0 +1,19 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+// backend API
+GGML_BACKEND_API ggml_backend_t ggml_backend_hexagon_init(void);
+
+GGML_BACKEND_API bool ggml_backend_is_hexagon(ggml_backend_t backend);
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_hexagon_reg(void);
+
+#ifdef  __cplusplus
+}
+#endif
index 3356ef550dec0d26f942eade83d4b15c5bf9ccf4..ba281b8e6d17aeb6d1b93707903729a80774ab00 100644 (file)
@@ -402,6 +402,7 @@ ggml_add_backend(Vulkan)
 ggml_add_backend(WebGPU)
 ggml_add_backend(zDNN)
 ggml_add_backend(OpenCL)
+ggml_add_backend(Hexagon)
 
 foreach (target ggml-base ggml)
     target_include_directories(${target} PUBLIC    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)
index 136afec748d9631c97a96807d3c5e9db0170731e..e96b5c403dd3f55483ad9cceb91bb48153e19193 100644 (file)
 #include "ggml-opencl.h"
 #endif
 
+#ifdef GGML_USE_HEXAGON
+#include "ggml-hexagon.h"
+#endif
+
 #ifdef GGML_USE_BLAS
 #include "ggml-blas.h"
 #endif
@@ -199,6 +203,9 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_OPENCL
         register_backend(ggml_backend_opencl_reg());
 #endif
+#ifdef GGML_USE_HEXAGON
+        register_backend(ggml_backend_hexagon_reg());
+#endif
 #ifdef GGML_USE_CANN
         register_backend(ggml_backend_cann_reg());
 #endif
@@ -598,6 +605,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
     ggml_backend_load_best("sycl", silent, dir_path);
     ggml_backend_load_best("vulkan", silent, dir_path);
     ggml_backend_load_best("opencl", silent, dir_path);
+    ggml_backend_load_best("hexagon", silent, dir_path);
     ggml_backend_load_best("musa", silent, dir_path);
     ggml_backend_load_best("cpu", silent, dir_path);
     // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
diff --git a/src/ggml-hexagon/CMakeLists.txt b/src/ggml-hexagon/CMakeLists.txt
new file mode 100644 (file)
index 0000000..166825c
--- /dev/null
@@ -0,0 +1,68 @@
+include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
+include(ExternalProject)
+
+option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
+
+add_library(htp_iface OBJECT
+    ${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
+
+set_target_properties(htp_iface PROPERTIES POSITION_INDEPENDENT_CODE ON)
+target_include_directories(htp_iface PUBLIC
+    ${HEXAGON_SDK_ROOT}/incs
+    ${HEXAGON_SDK_ROOT}/incs/stddef
+    ${HEXAGON_SDK_ROOT}/utils/examples
+    ${CMAKE_CURRENT_SOURCE_DIR}/htp
+    ${CMAKE_CURRENT_BINARY_DIR})
+
+build_idl(htp/htp_iface.idl htp_iface)
+
+if (CMAKE_SYSTEM_NAME MATCHES Android)
+    target_link_options(htp_iface PUBLIC -llog -ldl)
+elseif (CMAKE_SYSTEM_NAME MATCHES Windows)
+    target_precompile_headers(htp_iface PUBLIC <sal.h>)
+else()
+    target_link_options(htp_iface PUBLIC -ldl)
+endif()
+
+link_custom_library(htp_iface cdsprpc)
+link_custom_library(htp_iface rpcmem)
+
+set(TARGET_NAME ggml-hexagon)
+ggml_add_backend_library(${TARGET_NAME}
+    ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h)
+
+target_link_libraries(${TARGET_NAME} PRIVATE htp_iface)
+target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR})
+
+# Build HTP bits
+set(HTP_CMAKE_ARGS
+    -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
+    -DCMAKE_BUILD_TYPE=Release
+    -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
+    -DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT}
+    -DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
+    -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG})
+
+ExternalProject_Add(htp-v73
+    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
+    CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73")
+
+ExternalProject_Add(htp-v75
+    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
+    CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75")
+
+ExternalProject_Add(htp-v79
+    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
+    CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79")
+
+ExternalProject_Add(htp-v81
+    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
+    CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81")
+
+# Install Hexagon skels required at runtime
+install(FILES
+    ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so
+    ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so
+    ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so
+    ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so
+    TYPE LIB)
diff --git a/src/ggml-hexagon/ggml-hexagon.cpp b/src/ggml-hexagon/ggml-hexagon.cpp
new file mode 100644 (file)
index 0000000..ecfc1c8
--- /dev/null
@@ -0,0 +1,3757 @@
+#include <assert.h>
+#include <inttypes.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <time.h>
+
+#include <atomic>
+#include <chrono>
+#include <mutex>
+#include <string>
+
+#ifdef _WIN32
+#    include <sal.h>
+#    ifndef _WINDOWS
+#        define _WINDOWS
+#    endif
+#else
+#    include <semaphore.h>
+#    include <unistd.h>
+#endif
+
+#pragma clang diagnostic ignored "-Wnested-anon-types"
+#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
+
+#include "htp-utils.h"
+
+#include <AEEStdErr.h>
+#include <dspqueue.h>
+#include <rpcmem.h>
+
+#define GGML_COMMON_IMPL_CPP
+#include "ggml-backend-impl.h"
+#include "ggml-common.h"
+#include "ggml-hexagon.h"
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+#include "htp-msg.h"
+#include "htp_iface.h"
+
+static size_t opt_ndev         = 1;
+static size_t opt_nhvx         = 0;  // use all
+static int    opt_arch         = 0;  // autodetect
+static int    opt_etm          = 0;
+static int    opt_verbose      = 0;
+static int    opt_profile      = 0;
+static int    opt_hostbuf      = 1;
+static int    opt_experimental = 0;
+
+// Enable all stages by default
+static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;
+static int opt_opsync = 0;  // synchronous ops
+
+#define HEX_VERBOSE(...) \
+    if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__)
+
+#define HEX_PROFILE(...) \
+    if (opt_profile) GGML_LOG_INFO(__VA_ARGS__)
+
+static inline uint64_t hex_is_aligned(void * addr, uint32_t align) {
+    return ((size_t) addr & (align - 1)) == 0;
+}
+
+static inline size_t hex_round_up(size_t n, size_t m) {
+    return m * ((n + m - 1) / m);
+}
+
+static const char * status_to_str(uint32_t status) {
+    switch (status) {
+        case HTP_STATUS_OK:
+            return "OK";
+        case HTP_STATUS_NO_SUPPORT:
+            return "NO-SUPPORT";
+        case HTP_STATUS_INVAL_PARAMS:
+            return "INVAL-PARAMS";
+        case HTP_STATUS_VTCM_TOO_SMALL:
+            return "VTCM-TOO-SMALL";
+        case HTP_STATUS_INTERNAL_ERR:
+            return "INTERNAL-ERROR";
+        default:
+            return "UNKNOWN";
+    }
+}
+
+// ** debug helpers
+
+static inline int hex_format_tensor_dims(char * str, const struct ggml_tensor * t) {
+    if (t->ne[2] == 1 && t->ne[3] == 1) {
+        return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
+    } else {
+        return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
+    }
+}
+
+static inline void hex_format_op_dims(char * str, const struct ggml_tensor * t) {
+    char * p = str;
+
+    // append src0 and src1 (if any)
+    if (t->src[0]) {
+        p += hex_format_tensor_dims(p, t->src[0]);
+
+        for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
+            p += sprintf(p, " x ");
+            p += hex_format_tensor_dims(p, t->src[i]);
+        }
+
+        p += sprintf(p, " -> ");
+    }
+
+    // format self dims separately for better visual alignment
+    char self[64];
+    hex_format_tensor_dims(self, t);
+
+    p += sprintf(p, "%s", self);
+}
+
+static inline int hex_format_tensor_strides(char * str, const struct ggml_tensor * t) {
+    const char * c = ggml_is_contiguous(t) ? "" : "!";
+
+    if (t->ne[2] == 1 && t->ne[3] == 1) {
+        return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
+    } else {
+        return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2],
+                       (size_t) t->nb[3], c);
+    }
+}
+
+static inline void hex_format_op_strides(char * str, const struct ggml_tensor * t) {
+    char * p = str;
+
+    // append src0 and src1 (if any)
+    if (t->src[0]) {
+        p += hex_format_tensor_strides(p, t->src[0]);
+
+        for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
+            p += sprintf(p, " x ");
+            p += hex_format_tensor_strides(p, t->src[i]);
+        }
+
+        p += sprintf(p, " -> ");
+    }
+
+    // format self dims separately for better visual alignment
+    char self[64];
+    hex_format_tensor_strides(self, t);
+
+    p += sprintf(p, "%s", self);
+}
+
+static inline void hex_format_op_types(char * str, const struct ggml_tensor * t) {
+    char * p = str;
+
+    // append src0 and src1 (if any)
+    if (t->src[0]) {
+        p += sprintf(p, "%s", ggml_type_name(t->src[0]->type));
+
+        for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
+            p += sprintf(p, " x ");
+            p += sprintf(p, "%s", ggml_type_name(t->src[i]->type));
+        }
+
+        p += sprintf(p, " -> ");
+    }
+
+    p += sprintf(p, "%s", ggml_type_name(t->type));
+}
+
+static inline const char * hex_tensor_buff_name(const struct ggml_tensor * t) {
+    if (t->buffer) {
+        return ggml_backend_buffer_name(t->buffer);
+    }
+    return "NONE";
+}
+
+static inline void hex_format_op_buffs(char * str, const struct ggml_tensor * t) {
+    char * p = str;
+
+    // append src0 and src1 (if any)
+    if (t->src[0]) {
+        p += sprintf(p, "%s", hex_tensor_buff_name(t->src[0]));
+
+        for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
+            p += sprintf(p, " x ");
+            p += sprintf(p, "%s", hex_tensor_buff_name(t->src[i]));
+        }
+
+        p += sprintf(p, " -> ");
+    }
+
+    p += sprintf(p, "%s", hex_tensor_buff_name(t));
+}
+
+static inline void hex_format_op_names(char * str, const struct ggml_tensor * t) {
+    char * p = str;
+
+    // append src0 and src1 (if any)
+    if (t->src[0]) {
+        p += sprintf(p, "%s", t->src[0]->name);
+
+        for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
+            p += sprintf(p, " x ");
+            p += sprintf(p, "%s", t->src[i]->name);
+        }
+
+        p += sprintf(p, " -> ");
+    }
+
+    p += sprintf(p, "%s", t->name);
+}
+
+// ** backend sessions
+
+struct ggml_hexagon_session {
+    ggml_hexagon_session(int dev_id) noexcept(false);
+    ~ggml_hexagon_session() noexcept(true);
+
+    void allocate(int dev_id) noexcept(false);
+    void release() noexcept(true);
+
+    ggml_backend_buffer_type buffer_type;
+    ggml_backend_buffer_type repack_buffer_type;
+
+    std::string      name;
+    remote_handle64  handle;
+    dspqueue_t       queue;
+    uint32_t         session_id;
+    uint32_t         domain_id;
+    uint64_t         queue_id;
+    int              dev_id;
+    bool             valid_session;
+    bool             valid_handle;
+    bool             valid_queue;
+    bool             valid_iface;
+    std::atomic<int> op_pending;
+    uint32_t         prof_usecs;
+    uint32_t         prof_cycles;
+    uint32_t         prof_pkts;
+};
+
+// Packet callback
+static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * context) {
+    auto sess = static_cast<ggml_hexagon_session *>(context);
+
+    // Repeatedly read packets from the queue until it's empty. We don't
+    // necessarily get a separate callback for each packet, and new packets
+    // may arrive while we're processing the previous one.
+
+    while (1) {
+        struct htp_general_rsp rsp;
+        uint32_t               rsp_size;
+        uint32_t               flags;
+
+        struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
+        uint32_t               n_bufs;
+
+        // Read packet from queue
+        int err = dspqueue_read_noblock(queue, &flags,
+                                        HTP_MAX_PACKET_BUFFERS,  // Maximum number of buffer references
+                                        &n_bufs,                 // Number of buffer references
+                                        bufs,                    // Buffer references
+                                        sizeof(rsp),             // Max message length
+                                        &rsp_size,               // Message length
+                                        (uint8_t *) &rsp);
+
+        if (err == AEE_EWOULDBLOCK) {
+            // Consumed all packets available for now
+            return;
+        }
+
+        if (err != 0) {
+            GGML_ABORT("ggml-hex: dspqueue_read_noblock failed: 0x%08x\n", (unsigned) err);
+        }
+
+        // Basic sanity checks
+        if (rsp_size != sizeof(rsp)) {
+            GGML_ABORT("ggml-hex: dspcall : bad response (size)\n");
+        }
+
+        if (rsp.status != HTP_STATUS_OK) {
+            GGML_LOG_ERROR("ggml-hex: dspcall : dsp-rsp: %s\n", status_to_str(rsp.status));
+            // TODO: handle errors
+        }
+
+        // FIXME: update profiling implementation
+        sess->prof_usecs  = rsp.prof_usecs;
+        sess->prof_cycles = rsp.prof_cycles;
+        sess->prof_pkts   = rsp.prof_pkts;
+
+        sess->op_pending--;  // atomic dec
+    }
+}
+
+// Error callback - simply terminates with an error. Used where we don't
+// expect errors.
+[[noreturn]] static void htp_error_callback(dspqueue_t queue, AEEResult error, void * context) {
+    GGML_ABORT("ggml-hex: dspcall general error 0x%x: for queue %p\n", error, (void *) queue);
+}
+
+// ** backend buffers
+
+struct ggml_backend_hexagon_buffer_type_context {
+    ggml_backend_hexagon_buffer_type_context(const std::string & name, ggml_hexagon_session * sess) {
+        this->sess = sess;
+        this->name = name;
+    }
+
+    ggml_hexagon_session * sess;
+    std::string            name;
+};
+
+struct ggml_backend_hexagon_buffer_context {
+    bool mmap_to(ggml_hexagon_session * s) {
+        HEX_VERBOSE("ggml-hex: %s mmaping buffer: base %p domain-id %d session-id %d size %zu fd %d repack %d\n",
+                    s->name.c_str(), (void *) this->base, s->domain_id, s->session_id, this->size, this->fd,
+                    (int) this->repack);
+
+        int err = fastrpc_mmap(s->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD);
+        if (err != 0) {
+            GGML_LOG_ERROR("ggml-hex: buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n",
+                    s->domain_id, this->size, this->fd, (unsigned) err);
+            return false;
+        }
+
+        return true;
+    }
+
+    bool mmap() {
+        if (this->mapped) {
+            return true;
+        }
+        if (!mmap_to(this->sess)) {
+            return false;
+        }
+        this->mapped = true;
+        return true;
+    }
+
+    void munmap() {
+        if (!this->mapped) {
+            return;
+        }
+
+        fastrpc_munmap(this->sess->domain_id, this->fd, this->base, this->size);
+        this->mapped = false;
+    }
+
+    ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {
+        size += 4 * 1024;  // extra page for padding
+
+        this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
+        if (!this->base) {
+            GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size);
+            throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)");
+        }
+
+        this->fd = rpcmem_to_fd(this->base);
+        if (this->fd < 0) {
+            GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->name.c_str(), (void *) this->base);
+            rpcmem_free(this->base);
+            this->base = NULL;
+            throw std::runtime_error("ggml-hex: rpcmem_to_fd failed (see log for details)");
+        }
+
+        HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d repack %d\n", sess->name.c_str(),
+                    (void *) this->base, size, this->fd, (int) repack);
+
+        this->sess   = sess;
+        this->size   = size;
+        this->mapped = false;
+        this->repack = repack;
+    }
+
+    ~ggml_backend_hexagon_buffer_context() {
+        munmap();
+        if (this->base) {
+            rpcmem_free(this->base);
+            this->base = NULL;
+        }
+    }
+
+    ggml_hexagon_session * sess;  // primary session
+    uint8_t *              base;
+    size_t                 size;
+    int                    fd;
+    bool                   mapped;  // mmap is done
+    bool                   repack;  // repacked buffer
+};
+
+static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_buffer_t buffer) {
+    return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer->buft->context)->sess;
+}
+
+static void ggml_backend_hexagon_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
+    delete ctx;
+}
+
+static void * ggml_backend_hexagon_buffer_get_base(ggml_backend_buffer_t buffer) {
+    auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
+    return ctx->base;
+}
+
+static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    auto ctx  = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
+    auto sess = ctx->sess;
+
+    HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d repack %d\n", sess->name.c_str(),
+                tensor->name, (void *) ctx->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage,
+                (int) ctx->repack);
+
+    if (tensor->view_src != NULL && tensor->view_offs == 0) {
+        ; // nothing to do for the view
+    } else {
+        if (!ctx->mapped) {
+            ctx->mmap();
+        }
+    }
+    return GGML_STATUS_SUCCESS;
+}
+
+// ======== Q4x4x2 ====================
+struct x2_q4 {
+    int v[2];
+};
+
+static x2_q4 unpack_q4(uint8_t v) {
+    x2_q4 x = { (int) (v & 0x0f) - 8, (int) (v >> 4) - 8 };
+    return x;
+}
+
+static void dump_block_q4_0(const block_q4_0 * b, int i) {
+    HEX_VERBOSE("ggml-hex: repack q4_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_q4(b->qs[0]).v[0],
+                unpack_q4(b->qs[1]).v[0], unpack_q4(b->qs[2]).v[0], unpack_q4(b->qs[3]).v[0], unpack_q4(b->qs[12]).v[1],
+                unpack_q4(b->qs[13]).v[1], unpack_q4(b->qs[14]).v[1], unpack_q4(b->qs[15]).v[1],
+                GGML_FP16_TO_FP32(b->d));
+}
+
+static void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k) {
+    static const int qk        = QK_Q4_0x4x2;
+    const int        dblk_size = 8 * 2;   // 8x __fp16
+    const int        qblk_size = qk / 2;  // int4
+    const int        qrow_size = k / 2;   // int4 (not padded)
+
+    const uint8_t * v_q = v + 0;          // quants first
+    const uint8_t * v_d = v + qrow_size;  // then scales
+
+    const uint8_t *   q = v_q + i * qblk_size;
+    const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);
+
+    HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
+                unpack_q4(q[0]).v[0], unpack_q4(q[1]).v[0], unpack_q4(q[2]).v[0], unpack_q4(q[3]).v[0],
+                unpack_q4(q[60]).v[0], unpack_q4(q[61]).v[0], unpack_q4(q[62]).v[0], unpack_q4(q[63]).v[0],
+                unpack_q4(q[124]).v[0], unpack_q4(q[125]).v[0], unpack_q4(q[126]).v[0], unpack_q4(q[127]).v[0],
+                GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));
+
+    HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
+                i + 1, unpack_q4(q[0]).v[1], unpack_q4(q[1]).v[1], unpack_q4(q[2]).v[1], unpack_q4(q[3]).v[1],
+                unpack_q4(q[60]).v[1], unpack_q4(q[61]).v[1], unpack_q4(q[62]).v[1], unpack_q4(q[63]).v[1],
+                unpack_q4(q[124]).v[1], unpack_q4(q[125]).v[1], unpack_q4(q[126]).v[1], unpack_q4(q[127]).v[1],
+                GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));
+}
+
+static void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) {
+    static const int qk = QK4_0;
+
+    for (unsigned int i = 0; i < qk / 2; ++i) {
+        const int x0             = (x->qs[i] & 0x0F);
+        const int x1             = (x->qs[i] >> 4);
+        qs[bi * qk + i + 0]      = x0;
+        qs[bi * qk + i + qk / 2] = x1;
+    }
+}
+
+static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi) {
+    static const int qk = QK4_0;
+
+    for (unsigned int i = 0; i < qk / 2; ++i) {
+        const uint8_t x0 = qs[bi * qk + i + 0];
+        const uint8_t x1 = qs[bi * qk + i + qk / 2];
+        x->qs[i]         = x0 | (x1 << 4);
+    }
+}
+
+static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
+    static const int qk = QK_Q4_0x4x2;
+    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+
+    const int dblk_size = 8 * 2;              // 8x __fp16
+    const int qblk_size = qk / 2;             // int4
+    const int qrow_size = k / 2;              // int4 (not padded to blocks)
+
+    uint8_t * y_q = y + 0;                    // quants first
+    uint8_t * y_d = y + qrow_size;            // then scales
+
+    if (opt_verbose > 2) {
+        for (int i = 0; i < nb; i++) {
+            dump_block_q4_0(&x[i * 8 + 0], 0);
+            dump_block_q4_0(&x[i * 8 + 1], 1);
+            dump_block_q4_0(&x[i * 8 + 2], 2);
+            dump_block_q4_0(&x[i * 8 + 3], 3);
+            dump_block_q4_0(&x[i * 8 + 4], 4);
+            dump_block_q4_0(&x[i * 8 + 5], 5);
+            dump_block_q4_0(&x[i * 8 + 6], 6);
+            dump_block_q4_0(&x[i * 8 + 7], 7);
+        }
+    }
+
+    // Repack the quants
+    for (int i = 0; i < nb; i++) {
+        uint8_t qs[QK_Q4_0x4x2];  // unpacked quants
+        unpack_q4_0_quants(qs, &x[i * 8 + 0], 0);
+        unpack_q4_0_quants(qs, &x[i * 8 + 1], 1);
+        unpack_q4_0_quants(qs, &x[i * 8 + 2], 2);
+        unpack_q4_0_quants(qs, &x[i * 8 + 3], 3);
+        unpack_q4_0_quants(qs, &x[i * 8 + 4], 4);
+        unpack_q4_0_quants(qs, &x[i * 8 + 5], 5);
+        unpack_q4_0_quants(qs, &x[i * 8 + 6], 6);
+        unpack_q4_0_quants(qs, &x[i * 8 + 7], 7);
+
+        uint8_t * q = y_q + (i * qblk_size);
+        for (int j = 0; j < qk / 2; j++) {
+            q[j] = (qs[j + 128] << 4) | qs[j];
+        }
+    }
+
+    // Repack the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
+    // the last block is truncated and overriden by the scales.
+    for (int i = 0; i < nb; i++) {
+        // Repack the scales
+        ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
+        d[0]          = x[i * 8 + 0].d;
+        d[1]          = x[i * 8 + 1].d;
+        d[2]          = x[i * 8 + 2].d;
+        d[3]          = x[i * 8 + 3].d;
+        d[4]          = x[i * 8 + 4].d;
+        d[5]          = x[i * 8 + 5].d;
+        d[6]          = x[i * 8 + 6].d;
+        d[7]          = x[i * 8 + 7].d;
+    }
+
+    if (opt_verbose > 1) {
+        for (int i = 0; i < nb; i++) {
+            dump_packed_block_q4x4x2(y, i, k);
+        }
+    }
+}
+
+static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
+    static const int qk = QK_Q4_0x4x2;
+    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+
+    const int dblk_size = 8 * 2;              // 8x __fp16
+    const int qblk_size = qk / 2;             // int4
+    const int qrow_size = k / 2;              // int4 (not padded to blocks)
+
+    const uint8_t * y_q = y + 0;              // quants first
+    const uint8_t * y_d = y + qrow_size;      // then scales
+
+    if (opt_verbose > 1) {
+        for (int i = 0; i < nb; i++) {
+            dump_packed_block_q4x4x2(y, i, k);
+        }
+    }
+
+    // Unpack the quants
+    for (int i = 0; i < nb; i++) {
+        uint8_t qs[QK_Q4_0x4x2];  // unpacked quants
+
+        const uint8_t * q = y_q + (i * qblk_size);
+        for (int j = 0; j < qk / 2; j++) {
+            qs[j]       = q[j] & 0xf;
+            qs[j + 128] = q[j] >> 4;
+        }
+
+        pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
+        pack_q4_0_quants(&x[i * 8 + 1], qs, 1);
+        pack_q4_0_quants(&x[i * 8 + 2], qs, 2);
+        pack_q4_0_quants(&x[i * 8 + 3], qs, 3);
+        pack_q4_0_quants(&x[i * 8 + 4], qs, 4);
+        pack_q4_0_quants(&x[i * 8 + 5], qs, 5);
+        pack_q4_0_quants(&x[i * 8 + 6], qs, 6);
+        pack_q4_0_quants(&x[i * 8 + 7], qs, 7);
+    }
+
+    // Repack the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
+    // the last block is truncated and overriden by the scales.
+    for (int i = 0; i < nb; i++) {
+        // Unpack the scales
+        const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
+        x[i * 8 + 0].d      = d[0];
+        x[i * 8 + 1].d      = d[1];
+        x[i * 8 + 2].d      = d[2];
+        x[i * 8 + 3].d      = d[3];
+        x[i * 8 + 4].d      = d[4];
+        x[i * 8 + 5].d      = d[5];
+        x[i * 8 + 6].d      = d[6];
+        x[i * 8 + 7].d      = d[7];
+    }
+
+    if (opt_verbose > 2) {
+        for (int i = 0; i < nb; i++) {
+            dump_block_q4_0(&x[i * 8 + 0], 0);
+            dump_block_q4_0(&x[i * 8 + 1], 1);
+            dump_block_q4_0(&x[i * 8 + 2], 2);
+            dump_block_q4_0(&x[i * 8 + 3], 3);
+            dump_block_q4_0(&x[i * 8 + 4], 4);
+            dump_block_q4_0(&x[i * 8 + 5], 5);
+            dump_block_q4_0(&x[i * 8 + 6], 6);
+            dump_block_q4_0(&x[i * 8 + 7], 7);
+        }
+    }
+}
+
+static void init_row_q4x4x2(block_q4_0 * x, int64_t k) {
+    static const int qk = QK_Q4_0x4x2;
+    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+
+    // Init the quants such that they unpack into zeros
+    uint8_t qs[QK_Q4_0x4x2];  // unpacked quants
+    memset(qs, 8, sizeof(qs));
+
+    for (int i = 0; i < nb; i++) {
+        pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
+        pack_q4_0_quants(&x[i * 8 + 1], qs, 1);
+        pack_q4_0_quants(&x[i * 8 + 2], qs, 2);
+        pack_q4_0_quants(&x[i * 8 + 3], qs, 3);
+        pack_q4_0_quants(&x[i * 8 + 4], qs, 4);
+        pack_q4_0_quants(&x[i * 8 + 5], qs, 5);
+        pack_q4_0_quants(&x[i * 8 + 6], qs, 6);
+        pack_q4_0_quants(&x[i * 8 + 7], qs, 7);
+    }
+
+    // Init the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
+    // the last block is truncated and overriden by the scales.
+    for (int i = 0; i < nb; i++) {
+        // Unpack the scales
+        x[i * 8 + 0].d = 0;
+        x[i * 8 + 1].d = 0;
+        x[i * 8 + 2].d = 0;
+        x[i * 8 + 3].d = 0;
+        x[i * 8 + 4].d = 0;
+        x[i * 8 + 5].d = 0;
+        x[i * 8 + 6].d = 0;
+        x[i * 8 + 7].d = 0;
+    }
+}
+
+// repack q4_0 data into q4x4x2 tensor
+static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) {
+    int64_t nrows = ggml_nrows(t);
+
+    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
+    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));  // extra elements for the pad
+    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
+
+    void * buf_pd = ggml_aligned_malloc(row_size_pd);
+    GGML_ASSERT(buf_pd != NULL);
+
+    void * buf_rp = ggml_aligned_malloc(row_size_rp);
+    GGML_ASSERT(buf_rp != NULL);
+
+    HEX_VERBOSE("ggml-hex: repack-q4_0-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
+                t->ne[0], nrows, row_size);
+
+    init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros
+
+    for (int64_t i = 0; i < nrows; i++) {
+        const uint8_t * src = (const uint8_t *) data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
+
+        memcpy(buf_pd, src, row_size);
+        repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
+        memcpy(dst, buf_rp, row_size);
+    }
+
+    ggml_aligned_free(buf_pd, row_size_pd);
+    ggml_aligned_free(buf_rp, row_size_rp);
+}
+
+// repack q4x4x2 tensor into q4_0 data
+static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) {
+    int64_t nrows = ggml_nrows(t);
+
+    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
+    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));  // extra elements for the pad
+    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
+
+    void * buf_pd = ggml_aligned_malloc(row_size_pd);
+    GGML_ASSERT(buf_pd != NULL);
+
+    void * buf_rp = ggml_aligned_malloc(row_size_rp);
+    GGML_ASSERT(buf_rp != NULL);
+
+    HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
+                t->ne[0], nrows, row_size);
+
+    memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros
+
+    for (int64_t i = 0; i < nrows; i++) {
+        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) data + (i * row_size);
+
+        memcpy(buf_pd, src, row_size);
+        unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+        memcpy(dst, buf_rp, row_size);
+    }
+
+    ggml_aligned_free(buf_pd, row_size_pd);
+    ggml_aligned_free(buf_rp, row_size_rp);
+}
+
+// ======== Q8x4x2 ====================
+static void dump_block_q8_0(const block_q8_0 * b, int i) {
+    HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2],
+                b->qs[3], b->qs[28], b->qs[29], b->qs[30], b->qs[31], GGML_FP16_TO_FP32(b->d));
+}
+
+static void dump_packed_block_q8x4x2(const uint8_t * v, unsigned int i, size_t k) {
+    static const int qk        = QK_Q8_0x4x2;
+    const int        dblk_size = 8 * 2;   // 8x __fp16
+    const int        qblk_size = qk;      // int8
+    const int        qrow_size = k;       // int8 (not padded)
+
+    const uint8_t * v_q = v + 0;          // quants first
+    const uint8_t * v_d = v + qrow_size;  // then scales
+
+    const uint8_t *   q = v_q + i * qblk_size;
+    const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);
+
+    HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
+                q[0], q[1], q[2], q[3], q[60], q[61], q[62], q[63], q[124], q[125], q[126], q[127],
+                GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));
+
+    HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
+                i + 1, q[128], q[129], q[130], q[131], q[192], q[193], q[194], q[195], q[252], q[253], q[254], q[255],
+                GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));
+}
+
+static void unpack_q8_0_quants(uint8_t * qs, const block_q8_0 * x, unsigned int bi) {
+    static const int qk = QK8_0;
+
+    for (unsigned int i = 0; i < qk; ++i) {
+        qs[bi * qk + i] = x->qs[i];
+    }
+}
+
+static void pack_q8_0_quants(block_q8_0 * x, const uint8_t * qs, unsigned int bi) {
+    static const int qk = QK8_0;
+
+    for (unsigned int i = 0; i < qk; ++i) {
+        x->qs[i] = qs[bi * qk + i];
+    }
+}
+
+static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) {
+    static const int qk = QK_Q8_0x4x2;
+    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+
+    const int dblk_size = 8 * 2;              // 8x __fp16
+    const int qblk_size = qk;                 // int8
+    const int qrow_size = k;                  // int8 (not padded to blocks)
+
+    uint8_t * y_q = y + 0;                    // quants first
+    uint8_t * y_d = y + qrow_size;            // then scales
+
+    if (opt_verbose > 2) {
+        for (int i = 0; i < nb; i++) {
+            dump_block_q8_0(&x[i * 8 + 0], 0);
+            dump_block_q8_0(&x[i * 8 + 1], 1);
+            dump_block_q8_0(&x[i * 8 + 2], 2);
+            dump_block_q8_0(&x[i * 8 + 3], 3);
+            dump_block_q8_0(&x[i * 8 + 4], 4);
+            dump_block_q8_0(&x[i * 8 + 5], 5);
+            dump_block_q8_0(&x[i * 8 + 6], 6);
+            dump_block_q8_0(&x[i * 8 + 7], 7);
+        }
+    }
+
+    // Repack the quants
+    for (int i = 0; i < nb; i++) {
+        uint8_t qs[QK_Q8_0x4x2];  // unpacked quants
+
+        unpack_q8_0_quants(qs, &x[i * 8 + 0], 0);
+        unpack_q8_0_quants(qs, &x[i * 8 + 1], 1);
+        unpack_q8_0_quants(qs, &x[i * 8 + 2], 2);
+        unpack_q8_0_quants(qs, &x[i * 8 + 3], 3);
+        unpack_q8_0_quants(qs, &x[i * 8 + 4], 4);
+        unpack_q8_0_quants(qs, &x[i * 8 + 5], 5);
+        unpack_q8_0_quants(qs, &x[i * 8 + 6], 6);
+        unpack_q8_0_quants(qs, &x[i * 8 + 7], 7);
+
+        uint8_t * q = y_q + (i * qblk_size);
+        for (int j = 0; j < qk; j++) {
+            q[j] = qs[j];
+        }
+    }
+
+    // Repack the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
+    // the last block is truncated and overriden by the scales.
+    for (int i = 0; i < nb; i++) {
+        // Repack the scales
+        ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
+        d[0]          = x[i * 8 + 0].d;
+        d[1]          = x[i * 8 + 1].d;
+        d[2]          = x[i * 8 + 2].d;
+        d[3]          = x[i * 8 + 3].d;
+        d[4]          = x[i * 8 + 4].d;
+        d[5]          = x[i * 8 + 5].d;
+        d[6]          = x[i * 8 + 6].d;
+        d[7]          = x[i * 8 + 7].d;
+    }
+
+    if (opt_verbose > 1) {
+        for (int i = 0; i < nb; i++) {
+            dump_packed_block_q8x4x2(y, i, k);
+        }
+    }
+}
+
+static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) {
+    static const int qk = QK_Q8_0x4x2;
+    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+
+    const int dblk_size = 8 * 2;              // 8x __fp16
+    const int qblk_size = qk;                 // int8
+    const int qrow_size = k;                  // int8 (not padded to blocks)
+
+    const uint8_t * y_q = y + 0;              // quants first
+    const uint8_t * y_d = y + qrow_size;      // then scales
+
+    if (opt_verbose > 1) {
+        for (int i = 0; i < nb; i++) {
+            dump_packed_block_q8x4x2(y, i, k);
+        }
+    }
+
+    // Unpack the quants
+    for (int i = 0; i < nb; i++) {
+        uint8_t qs[QK_Q4_0x4x2];  // unpacked quants
+
+        const uint8_t * q = y_q + (i * qblk_size);
+        for (int j = 0; j < qk; j++) {
+            qs[j] = q[j];
+        }
+
+        pack_q8_0_quants(&x[i * 8 + 0], qs, 0);
+        pack_q8_0_quants(&x[i * 8 + 1], qs, 1);
+        pack_q8_0_quants(&x[i * 8 + 2], qs, 2);
+        pack_q8_0_quants(&x[i * 8 + 3], qs, 3);
+        pack_q8_0_quants(&x[i * 8 + 4], qs, 4);
+        pack_q8_0_quants(&x[i * 8 + 5], qs, 5);
+        pack_q8_0_quants(&x[i * 8 + 6], qs, 6);
+        pack_q8_0_quants(&x[i * 8 + 7], qs, 7);
+    }
+
+    // Repack the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
+    // the last block is truncated and overriden by the scales.
+    for (int i = 0; i < nb; i++) {
+        // Unpack the scales
+        const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
+        x[i * 8 + 0].d      = d[0];
+        x[i * 8 + 1].d      = d[1];
+        x[i * 8 + 2].d      = d[2];
+        x[i * 8 + 3].d      = d[3];
+        x[i * 8 + 4].d      = d[4];
+        x[i * 8 + 5].d      = d[5];
+        x[i * 8 + 6].d      = d[6];
+        x[i * 8 + 7].d      = d[7];
+    }
+
+    if (opt_verbose > 2) {
+        for (int i = 0; i < nb; i++) {
+            dump_block_q8_0(&x[i * 8 + 0], 0);
+            dump_block_q8_0(&x[i * 8 + 1], 1);
+            dump_block_q8_0(&x[i * 8 + 2], 2);
+            dump_block_q8_0(&x[i * 8 + 3], 3);
+            dump_block_q8_0(&x[i * 8 + 4], 4);
+            dump_block_q8_0(&x[i * 8 + 5], 5);
+            dump_block_q8_0(&x[i * 8 + 6], 6);
+            dump_block_q8_0(&x[i * 8 + 7], 7);
+        }
+    }
+}
+
+static void init_row_q8x4x2(block_q8_0 * x, int64_t k) {
+    static const int qk = QK_Q8_0x4x2;
+    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+
+    // Init the quants such that they unpack into zeros
+    uint8_t qs[QK_Q8_0x4x2];  // unpacked quants
+    memset(qs, 0, sizeof(qs));
+
+    for (int i = 0; i < nb; i++) {
+        pack_q8_0_quants(&x[i * 8 + 0], qs, 0);
+        pack_q8_0_quants(&x[i * 8 + 1], qs, 1);
+        pack_q8_0_quants(&x[i * 8 + 2], qs, 2);
+        pack_q8_0_quants(&x[i * 8 + 3], qs, 3);
+        pack_q8_0_quants(&x[i * 8 + 4], qs, 4);
+        pack_q8_0_quants(&x[i * 8 + 5], qs, 5);
+        pack_q8_0_quants(&x[i * 8 + 6], qs, 6);
+        pack_q8_0_quants(&x[i * 8 + 7], qs, 7);
+    }
+
+    // Init the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2)
+    // the last block is truncated and overriden by the scales.
+    for (int i = 0; i < nb; i++) {
+        // Unpack the scales
+        x[i * 8 + 0].d = 0;
+        x[i * 8 + 1].d = 0;
+        x[i * 8 + 2].d = 0;
+        x[i * 8 + 3].d = 0;
+        x[i * 8 + 4].d = 0;
+        x[i * 8 + 5].d = 0;
+        x[i * 8 + 6].d = 0;
+        x[i * 8 + 7].d = 0;
+    }
+}
+
+// repack q8_0 data into q8x4x2 tensor
+static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) {
+    int64_t nrows = ggml_nrows(t);
+
+    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
+    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2));  // extra elements for the pad
+    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
+
+    void * buf_pd = ggml_aligned_malloc(row_size_pd);
+    GGML_ASSERT(buf_pd != NULL);
+
+    void * buf_rp = ggml_aligned_malloc(row_size_rp);
+    GGML_ASSERT(buf_rp != NULL);
+
+    HEX_VERBOSE("ggml-hex: repack-q8_0-q8x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
+                t->ne[0], nrows, row_size);
+
+    init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros
+
+    for (int64_t i = 0; i < nrows; i++) {
+        const uint8_t * src = (const uint8_t *) data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
+
+        memcpy(buf_pd, src, row_size);
+        repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);
+        memcpy(dst, buf_rp, row_size);
+    }
+
+    ggml_aligned_free(buf_pd, row_size_pd);
+    ggml_aligned_free(buf_rp, row_size_rp);
+}
+
+// repack q8x4x2 tensor into q8_0 data
+static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) {
+    int64_t nrows = ggml_nrows(t);
+
+    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
+    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2));  // extra elements for the pad
+    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
+
+    void * buf_pd = ggml_aligned_malloc(row_size_pd);
+    GGML_ASSERT(buf_pd != NULL);
+
+    void * buf_rp = ggml_aligned_malloc(row_size_rp);
+    GGML_ASSERT(buf_rp != NULL);
+
+    HEX_VERBOSE("ggml-hex: repack-q8x4x2-q8_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
+                t->ne[0], nrows, row_size);
+
+    memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros
+
+    for (int64_t i = 0; i < nrows; i++) {
+        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) data + (i * row_size);
+
+        memcpy(buf_pd, src, row_size);
+        unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+        memcpy(dst, buf_rp, row_size);
+    }
+
+    ggml_aligned_free(buf_pd, row_size_pd);
+    ggml_aligned_free(buf_rp, row_size_rp);
+}
+
+// ======== MXFP4x4x2 ====================
+struct x2_mxfp4 {
+    int v[2];
+};
+
+static x2_mxfp4 unpack_mxfp4(uint8_t v) {
+    x2_mxfp4 x;
+    x.v[0] = kvalues_mxfp4[(v & 0x0f)];
+    x.v[1] = kvalues_mxfp4[(v >> 4)];
+    return x;
+}
+
+static void dump_block_mxfp4(const block_mxfp4 * b, int i) {
+    HEX_VERBOSE("ggml-hex: repack mxfp4 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_mxfp4(b->qs[0]).v[0],
+                unpack_mxfp4(b->qs[1]).v[0], unpack_mxfp4(b->qs[2]).v[0], unpack_mxfp4(b->qs[3]).v[0],
+                unpack_mxfp4(b->qs[12]).v[1], unpack_mxfp4(b->qs[13]).v[1], unpack_mxfp4(b->qs[14]).v[1],
+                unpack_mxfp4(b->qs[15]).v[1], GGML_E8M0_TO_FP32_HALF(b->e));
+}
+
+static void dump_packed_block_mxfp4x4x2(const uint8_t * v, unsigned int i, size_t k) {
+    static const int qk        = QK_MXFP4x4x2;
+    const int        eblk_size = 8 * 1;   // 8x E8M0
+    const int        qblk_size = qk / 2;  // int4
+    const int        qrow_size = k / 2;   // int4 (not padded)
+
+    const uint8_t * v_q = v + 0;          // quants first
+    const uint8_t * v_e = v + qrow_size;  // then scales
+
+    const uint8_t * q = v_q + i * qblk_size;
+    const uint8_t * e = (const uint8_t *) (v_e + i * eblk_size);
+
+    HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
+                unpack_mxfp4(q[0]).v[0], unpack_mxfp4(q[1]).v[0], unpack_mxfp4(q[2]).v[0], unpack_mxfp4(q[3]).v[0],
+                unpack_mxfp4(q[60]).v[0], unpack_mxfp4(q[61]).v[0], unpack_mxfp4(q[62]).v[0], unpack_mxfp4(q[63]).v[0],
+                unpack_mxfp4(q[124]).v[0], unpack_mxfp4(q[125]).v[0], unpack_mxfp4(q[126]).v[0],
+                unpack_mxfp4(q[127]).v[0], GGML_E8M0_TO_FP32_HALF(e[0]), GGML_E8M0_TO_FP32_HALF(e[1]),
+                GGML_E8M0_TO_FP32_HALF(e[2]), GGML_E8M0_TO_FP32_HALF(e[3]));
+
+    HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
+                i + 1, unpack_mxfp4(q[0]).v[1], unpack_mxfp4(q[1]).v[1], unpack_mxfp4(q[2]).v[1],
+                unpack_mxfp4(q[3]).v[1], unpack_mxfp4(q[60]).v[1], unpack_mxfp4(q[61]).v[1], unpack_mxfp4(q[62]).v[1],
+                unpack_mxfp4(q[63]).v[1], unpack_mxfp4(q[124]).v[1], unpack_mxfp4(q[125]).v[1],
+                unpack_mxfp4(q[126]).v[1], unpack_mxfp4(q[127]).v[1], GGML_E8M0_TO_FP32_HALF(e[4]),
+                GGML_E8M0_TO_FP32_HALF(e[5]), GGML_E8M0_TO_FP32_HALF(e[6]), GGML_E8M0_TO_FP32_HALF(e[7]));
+}
+
+static void unpack_mxfp4_quants(uint8_t * qs, const block_mxfp4 * x, unsigned int bi) {
+    static const int qk = QK_MXFP4;
+
+    for (unsigned int i = 0; i < qk / 2; ++i) {
+        const uint8_t x0         = (x->qs[i] & 0x0F);
+        const uint8_t x1         = (x->qs[i] >> 4);
+        qs[bi * qk + i + 0]      = x0;
+        qs[bi * qk + i + qk / 2] = x1;
+    }
+}
+
+static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int bi) {
+    static const int qk = QK4_0;
+
+    for (unsigned int i = 0; i < qk / 2; ++i) {
+        const uint8_t x0 = qs[bi * qk + i + 0];
+        const uint8_t x1 = qs[bi * qk + i + qk / 2];
+        x->qs[i]         = x0 | (x1 << 4);
+    }
+}
+
+static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) {
+    static const int qk = QK_MXFP4x4x2;
+    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+
+    const int eblk_size = 8 * 1;              // 8x E8M0
+    const int qblk_size = qk / 2;             // int4
+    const int qrow_size = k / 2;              // int4 (not padded to blocks)
+
+    uint8_t * y_q = y + 0;                    // quants first
+    uint8_t * y_e = y + qrow_size;            // then scales
+
+    if (opt_verbose > 2) {
+        for (int i = 0; i < nb; i++) {
+            dump_block_mxfp4(&x[i * 8 + 0], 0);
+            dump_block_mxfp4(&x[i * 8 + 1], 1);
+            dump_block_mxfp4(&x[i * 8 + 2], 2);
+            dump_block_mxfp4(&x[i * 8 + 3], 3);
+            dump_block_mxfp4(&x[i * 8 + 4], 4);
+            dump_block_mxfp4(&x[i * 8 + 5], 5);
+            dump_block_mxfp4(&x[i * 8 + 6], 6);
+            dump_block_mxfp4(&x[i * 8 + 7], 7);
+        }
+    }
+
+    // Repack the quants
+    for (int i = 0; i < nb; i++) {
+        uint8_t qs[QK_MXFP4x4x2];  // unpacked quants
+
+        unpack_mxfp4_quants(qs, &x[i * 8 + 0], 0);
+        unpack_mxfp4_quants(qs, &x[i * 8 + 1], 1);
+        unpack_mxfp4_quants(qs, &x[i * 8 + 2], 2);
+        unpack_mxfp4_quants(qs, &x[i * 8 + 3], 3);
+        unpack_mxfp4_quants(qs, &x[i * 8 + 4], 4);
+        unpack_mxfp4_quants(qs, &x[i * 8 + 5], 5);
+        unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6);
+        unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7);
+
+        uint8_t * q = y_q + (i * qblk_size);
+        for (int j = 0; j < qk / 2; j++) {
+            q[j] = (qs[j + 128] << 4) | qs[j];
+        }
+    }
+
+    // Repack the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
+    // the last block is truncated and overriden by the scales.
+    for (int i = 0; i < nb; i++) {
+        // Repack the scales
+        uint8_t * e = (uint8_t *) (y_e + i * eblk_size);
+        e[0]        = x[i * 8 + 0].e;
+        e[1]        = x[i * 8 + 1].e;
+        e[2]        = x[i * 8 + 2].e;
+        e[3]        = x[i * 8 + 3].e;
+        e[4]        = x[i * 8 + 4].e;
+        e[5]        = x[i * 8 + 5].e;
+        e[6]        = x[i * 8 + 6].e;
+        e[7]        = x[i * 8 + 7].e;
+    }
+
+    if (opt_verbose > 1) {
+        for (int i = 0; i < nb; i++) {
+            dump_packed_block_mxfp4x4x2(y, i, k);
+        }
+    }
+}
+
+static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) {
+    static const int qk = QK_MXFP4x4x2;
+    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+
+    const int eblk_size = 8 * 1;              // 8x E8M0
+    const int qblk_size = qk / 2;             // int4
+    const int qrow_size = k / 2;              // int4 (not padded to blocks)
+
+    const uint8_t * y_q = y + 0;              // quants first
+    const uint8_t * y_e = y + qrow_size;      // then scales
+
+    if (opt_verbose > 1) {
+        for (int i = 0; i < nb; i++) {
+            dump_packed_block_mxfp4x4x2(y, i, k);
+        }
+    }
+
+    // Unpack the quants
+    for (int i = 0; i < nb; i++) {
+        uint8_t qs[QK_MXFP4x4x2];  // unpacked quants
+
+        const uint8_t * q = y_q + (i * qblk_size);
+        for (int j = 0; j < qk / 2; j++) {
+            qs[j]       = q[j] & 0xf;
+            qs[j + 128] = q[j] >> 4;
+        }
+
+        pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
+        pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);
+        pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);
+        pack_mxfp4_quants(&x[i * 8 + 3], qs, 3);
+        pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);
+        pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);
+        pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);
+        pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);
+    }
+
+    // Repack the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2)
+    // the last block is truncated and overriden by the scales.
+    for (int i = 0; i < nb; i++) {
+        // Unpack the scales
+        const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size);
+        x[i * 8 + 0].e    = e[0];
+        x[i * 8 + 1].e    = e[1];
+        x[i * 8 + 2].e    = e[2];
+        x[i * 8 + 3].e    = e[3];
+        x[i * 8 + 4].e    = e[4];
+        x[i * 8 + 5].e    = e[5];
+        x[i * 8 + 6].e    = e[6];
+        x[i * 8 + 7].e    = e[7];
+    }
+
+    if (opt_verbose > 2) {
+        for (int i = 0; i < nb; i++) {
+            dump_block_mxfp4(&x[i * 8 + 0], 0);
+            dump_block_mxfp4(&x[i * 8 + 1], 1);
+            dump_block_mxfp4(&x[i * 8 + 2], 2);
+            dump_block_mxfp4(&x[i * 8 + 3], 3);
+            dump_block_mxfp4(&x[i * 8 + 4], 4);
+            dump_block_mxfp4(&x[i * 8 + 5], 5);
+            dump_block_mxfp4(&x[i * 8 + 6], 6);
+            dump_block_mxfp4(&x[i * 8 + 7], 7);
+        }
+    }
+}
+
+static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) {
+    static const int qk = QK_MXFP4x4x2;
+    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+
+    // Init the quants such that they unpack into zeros
+    uint8_t qs[QK_MXFP4x4x2];  // unpacked quants
+    memset(qs, 0, sizeof(qs));
+
+    for (int i = 0; i < nb; i++) {
+        pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
+        pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);
+        pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);
+        pack_mxfp4_quants(&x[i * 8 + 3], qs, 3);
+        pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);
+        pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);
+        pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);
+        pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);
+    }
+
+    // Init the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
+    // the last block is truncated and overriden by the scales.
+    for (int i = 0; i < nb; i++) {
+        // Unpack the scales
+        x[i * 8 + 0].e = 0;
+        x[i * 8 + 1].e = 0;
+        x[i * 8 + 2].e = 0;
+        x[i * 8 + 3].e = 0;
+        x[i * 8 + 4].e = 0;
+        x[i * 8 + 5].e = 0;
+        x[i * 8 + 6].e = 0;
+        x[i * 8 + 7].e = 0;
+    }
+}
+
+// repack mxfp4 data into mxfp4x4x2 tensor
+static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t size) {
+    int64_t nrows = ggml_nrows(t);
+
+    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
+    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2));  // extra elements for the pad
+    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
+
+    void * buf_pd = ggml_aligned_malloc(row_size_pd);
+    GGML_ASSERT(buf_pd != NULL);
+
+    void * buf_rp = ggml_aligned_malloc(row_size_rp);
+    GGML_ASSERT(buf_rp != NULL);
+
+    HEX_VERBOSE("ggml-hex: repack-mxfp4-mxfp4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data,
+                size, t->ne[0], nrows, row_size);
+
+    init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros
+
+    for (int64_t i = 0; i < nrows; i++) {
+        const uint8_t * src = (const uint8_t *) data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
+
+        memcpy(buf_pd, src, row_size);
+        repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);
+        memcpy(dst, buf_rp, row_size);
+    }
+
+    ggml_aligned_free(buf_pd, row_size_pd);
+    ggml_aligned_free(buf_rp, row_size_rp);
+}
+
+// repack mxfp4x4x2 tensor into mxfp4 data
+static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t size) {
+    int64_t nrows = ggml_nrows(t);
+
+    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
+    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2));  // extra elements for the pad
+    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
+
+    void * buf_pd = ggml_aligned_malloc(row_size_pd);
+    GGML_ASSERT(buf_pd != NULL);
+
+    void * buf_rp = ggml_aligned_malloc(row_size_rp);
+    GGML_ASSERT(buf_rp != NULL);
+
+    HEX_VERBOSE("ggml-hex: repack-mxfp4x4x2-mxfp4 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data,
+                size, t->ne[0], nrows, row_size);
+
+    memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros
+
+    for (int64_t i = 0; i < nrows; i++) {
+        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) data + (i * row_size);
+
+        memcpy(buf_pd, src, row_size);
+        unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+        memcpy(dst, buf_rp, row_size);
+    }
+
+    ggml_aligned_free(buf_pd, row_size_pd);
+    ggml_aligned_free(buf_rp, row_size_rp);
+}
+
+static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
+                                                   ggml_tensor *         tensor,
+                                                   const void *          data,
+                                                   size_t                offset,
+                                                   size_t                size) {
+    auto ctx  = (ggml_backend_hexagon_buffer_context *) buffer->context;
+    auto sess = ctx->sess;
+
+    HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data,
+                offset, size);
+
+    switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            GGML_ASSERT(offset == 0);
+            GGML_ASSERT(size == ggml_nbytes(tensor));
+            repack_q4_0_q4x4x2(tensor, data, size);
+            break;
+
+        case GGML_TYPE_Q8_0:
+            GGML_ASSERT(offset == 0);
+            GGML_ASSERT(size == ggml_nbytes(tensor));
+            repack_q8_0_q8x4x2(tensor, data, size);
+            break;
+
+        case GGML_TYPE_MXFP4:
+            GGML_ASSERT(offset == 0);
+            GGML_ASSERT(size == ggml_nbytes(tensor));
+            repack_mxfp4_mxfp4x4x2(tensor, data, size);
+            break;
+
+        default:
+            memcpy((char *) tensor->data + offset, data, size);
+            break;
+    }
+}
+
+static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
+                                                   const ggml_tensor *   tensor,
+                                                   void *                data,
+                                                   size_t                offset,
+                                                   size_t                size) {
+    auto ctx  = (ggml_backend_hexagon_buffer_context *) buffer->context;
+    auto sess = ctx->sess;
+
+    HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data,
+                offset, size);
+
+    switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            GGML_ASSERT(offset == 0);
+            GGML_ASSERT(size == ggml_nbytes(tensor));
+            repack_q4x4x2_q4_0(data, tensor, size);
+            break;
+
+        case GGML_TYPE_Q8_0:
+            GGML_ASSERT(offset == 0);
+            GGML_ASSERT(size == ggml_nbytes(tensor));
+            repack_q8x4x2_q8_0(data, tensor, size);
+            break;
+
+        case GGML_TYPE_MXFP4:
+            GGML_ASSERT(offset == 0);
+            GGML_ASSERT(size == ggml_nbytes(tensor));
+            repack_mxfp4x4x2_mxfp4(data, tensor, size);
+            break;
+
+        default:
+            memcpy(data, (const char *) tensor->data + offset, size);
+            break;
+    }
+}
+
+static bool ggml_backend_hexagon_buffer_cpy_tensor(ggml_backend_buffer_t      buffer,
+                                                   const struct ggml_tensor * src,
+                                                   struct ggml_tensor *       dst) {
+    GGML_UNUSED(buffer);
+    GGML_UNUSED(src);
+    GGML_UNUSED(dst);
+    // we might optimize this later, for now take the slow path (ie get/set_tensor)
+    return false;
+}
+
+static void ggml_backend_hexagon_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    auto ctx  = (ggml_backend_hexagon_buffer_context *) buffer->context;
+    auto sess = ctx->sess;
+    HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->name.c_str(), (void *) ctx->base, ctx->size);
+    memset(ctx->base, value, ctx->size);
+}
+
+static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = {
+    /* .free_buffer     = */ ggml_backend_hexagon_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_hexagon_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_hexagon_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
+    /* .set_tensor      = */ ggml_backend_hexagon_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_hexagon_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_hexagon_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_hexagon_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+// ** backend buffer type
+
+static const char * ggml_backend_hexagon_buffer_type_name(ggml_backend_buffer_type_t buffer_type) {
+    return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->name.c_str();
+}
+
+static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer(
+            ggml_backend_buffer_type_t buffer_type, size_t size) {
+    auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
+    try {
+        ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/);
+        return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);
+    } catch (std::exception const &exc) {
+        GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what());
+        return nullptr;
+    }
+}
+
+static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffer(
+            ggml_backend_buffer_type_t buffer_type, size_t size) {
+    auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
+    try {
+        ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/);
+        return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);
+    } catch (std::exception const &exc) {
+        GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what());
+        return nullptr;
+    }
+}
+
+static size_t ggml_backend_hexagon_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
+    return 128;  // HVX alignment
+    GGML_UNUSED(buffer_type);
+}
+
+static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) {
+    return ggml_nbytes(t);
+}
+
+static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
+    return 1 * 1024 * 1024 * 1024;  // 1GB per buffer
+    GGML_UNUSED(buffer_type);
+}
+
+static bool ggml_backend_hexagon_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return opt_hostbuf;
+    GGML_UNUSED(buft);
+}
+
+static bool ggml_backend_hexagon_repack_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return false;
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_hexagon_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_hexagon_buffer_type_name,
+    /* .alloc_buffer     = */ ggml_backend_hexagon_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_hexagon_buffer_type_get_alignment,
+    /* .get_max_size     = */ ggml_backend_hexagon_buffer_type_get_max_size,
+    /* .get_alloc_size   = */ ggml_backend_hexagon_buffer_type_get_alloc_size,
+    /* .is_host          = */ ggml_backend_hexagon_buffer_type_is_host,
+};
+
+static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_hexagon_buffer_type_name,
+    /* .alloc_buffer     = */ ggml_backend_hexagon_repack_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_hexagon_buffer_type_get_alignment,
+    /* .get_max_size     = */ ggml_backend_hexagon_buffer_type_get_max_size,
+    /* .get_alloc_size   = */ ggml_backend_hexagon_buffer_type_get_alloc_size,
+    /* .is_host          = */ ggml_backend_hexagon_repack_buffer_type_is_host,
+};
+
+void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
+    this->valid_session = false;
+    this->valid_handle  = false;
+    this->valid_queue   = false;
+    this->valid_iface   = false;
+
+    this->domain_id  = 3;  // Default for CDSP, updated after the session is created
+    this->session_id = 0;  // Default for CDSP, updated after the session is created
+    this->dev_id     = dev_id;
+    this->name       = std::string("HTP") + std::to_string(dev_id);
+
+    this->op_pending  = 0;
+    this->prof_usecs  = 0;
+    this->prof_cycles = 0;
+    this->prof_pkts   = 0;
+
+    GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str());
+
+    domain * my_domain = get_domain(this->domain_id);
+    if (my_domain == NULL) {
+        GGML_LOG_ERROR("ggml-hex: unable to get domain struct for CDSP\n");
+        throw std::runtime_error("ggml-hex: failed to get CDSP domain (see log for details)");
+    }
+
+    // Create new session
+    if (dev_id != 0) {
+        struct remote_rpc_reserve_new_session n;
+        n.domain_name_len  = strlen(CDSP_DOMAIN_NAME);
+        n.domain_name      = const_cast<char *>(CDSP_DOMAIN_NAME);
+        n.session_name     = const_cast<char *>(this->name.c_str());
+        n.session_name_len = this->name.size();
+
+        int err = remote_session_control(FASTRPC_RESERVE_NEW_SESSION, (void *) &n, sizeof(n));
+        if (err != AEE_SUCCESS) {
+            GGML_LOG_ERROR("ggml-hex: failed to reserve new session %d : error 0x%x\n", dev_id, err);
+            throw std::runtime_error("ggml-hex: remote_session_control(new-sess) failed (see log for details)");
+        }
+
+        // Save the IDs
+        this->session_id = n.session_id;
+        this->domain_id  = n.effective_domain_id;
+        this->valid_session = true;
+    }
+
+    // Get session URI
+    char htp_uri[256];
+    sprintf(htp_uri, "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch);
+
+    char session_uri[256];
+    {
+        struct remote_rpc_get_uri u;
+        u.session_id      = this->session_id;
+        u.domain_name     = const_cast<char *>(CDSP_DOMAIN_NAME);
+        u.domain_name_len = strlen(CDSP_DOMAIN_NAME);
+        u.module_uri      = const_cast<char *>(htp_uri);
+        u.module_uri_len  = strlen(htp_uri);
+        u.uri             = session_uri;
+        u.uri_len         = sizeof(session_uri);
+
+        int err = remote_session_control(FASTRPC_GET_URI, (void *) &u, sizeof(u));
+        if (err != AEE_SUCCESS) {
+            GGML_LOG_ERROR("ggml-hex: failed to get URI for session %d : error 0x%x\n", dev_id, err);
+            throw std::runtime_error("ggml-hex: remote_session_control(get-uri) failed (see log for details)");
+        }
+    }
+
+    // Enable Unsigned PD
+    {
+        struct remote_rpc_control_unsigned_module u;
+        u.domain = this->domain_id;
+        u.enable = 1;
+        int err  = remote_session_control(DSPRPC_CONTROL_UNSIGNED_MODULE, (void *) &u, sizeof(u));
+        if (err != AEE_SUCCESS) {
+            GGML_LOG_ERROR("ggml-hex: failed to enable unsigned PD for session %d : error 0x%x\n", dev_id, err);
+            throw std::runtime_error("ggml-hex: remote_session_control(unsign) failed (see log for details)");
+        }
+    }
+
+    // Open session
+    int err = htp_iface_open(session_uri, &this->handle);
+    if (err != AEE_SUCCESS) {
+        GGML_LOG_ERROR("ggml-hex: failed to open session %d : error 0x%x\n", dev_id, err);
+        throw std::runtime_error("ggml-hex: failed to open session (see log for details)");
+    }
+
+    this->valid_handle = true;
+
+    GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(),
+            this->session_id, this->domain_id, session_uri, (unsigned long) this->handle);
+
+    // Enable FastRPC QoS mode
+    {
+        struct remote_rpc_control_latency l;
+        l.enable = 1;
+
+        int err = remote_handle64_control(this->handle, DSPRPC_CONTROL_LATENCY, (void *) &l, sizeof(l));
+        if (err != 0) {
+            GGML_LOG_WARN("ggml-hex: failed to enable fastrpc QOS mode: 0x%08x\n", (unsigned) err);
+        }
+    }
+
+    // Now let's setup the DSP queue
+    err = dspqueue_create(this->domain_id,
+                          0,              // Flags
+                          128 * 1024,     // Request  queue size (in bytes)
+                          64 * 1024,      // Response queue size (in bytes)
+                          htp_packet_callback, htp_error_callback,
+                          (void *) this,  // Callback context
+                          &queue);
+    if (err != 0) {
+        GGML_LOG_ERROR("ggml-hex: %s dspqueue_create failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
+        throw std::runtime_error("ggml-hex: failed to create dspqueue (see log for details)");
+    }
+
+    this->valid_queue = true;
+
+    // Export queue for use on the DSP
+    err = dspqueue_export(queue, &this->queue_id);
+    if (err != 0) {
+        GGML_LOG_ERROR("ggml-hex: dspqueue_export failed: 0x%08x\n", (unsigned) err);
+        throw std::runtime_error("ggml-hex: dspqueue export failed (see log for details)");
+    }
+
+    if (opt_etm) {
+        err = htp_iface_enable_etm(this->handle);
+        if (err != 0) {
+            GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err);
+        }
+    }
+
+    // Start the DSP-side service. We need to pass the queue ID to the
+    // DSP in a FastRPC call; the DSP side will import the queue and start
+    // listening for packets in a callback.
+    err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx);
+    if (err != 0) {
+        GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
+        throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
+    }
+    this->valid_iface = true;
+}
+
+void ggml_hexagon_session::release() noexcept(true) {
+    GGML_LOG_INFO("ggml-hex: releasing session: %s\n", this->name.c_str());
+
+    int err;
+
+    // Stop the DSP-side service and close the queue
+    if (this->valid_iface) {
+        err = htp_iface_stop(this->handle);
+        if (err != 0) {
+            GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err);
+        }
+    }
+
+    if (opt_etm) {
+        err = htp_iface_disable_etm(this->handle);
+        if (err != 0) {
+            GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err);
+        }
+    }
+
+    if (this->valid_queue) {
+        err = dspqueue_close(queue);
+        if (err != 0) {
+            GGML_ABORT("ggml-hex: dspqueue_close failed: 0x%08x\n", (unsigned) err);
+        }
+    }
+
+    if (this->valid_handle) {
+        htp_iface_close(this->handle);
+    }
+}
+
+ggml_hexagon_session::ggml_hexagon_session(int dev_id) noexcept(false) {
+    buffer_type.context        = nullptr;
+    repack_buffer_type.context = nullptr;
+
+    try {
+        allocate(dev_id);
+
+        buffer_type.iface   = ggml_backend_hexagon_buffer_type_interface;
+        buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name, this);
+
+        repack_buffer_type.iface   = ggml_backend_hexagon_repack_buffer_type_interface;
+        repack_buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name + "-REPACK", this);
+    } catch (std::exception const &exc) {
+        release();
+        throw;
+    }
+}
+
+ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {
+    release();
+
+    delete static_cast<ggml_backend_hexagon_buffer_type_context*>(buffer_type.context);
+    delete static_cast<ggml_backend_hexagon_buffer_type_context*>(repack_buffer_type.context);
+}
+
+// ** backend interface
+
+static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) {
+    return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment;
+}
+
+static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) {
+    return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
+}
+
+static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
+    if (x->ne[0] != y->ne[0]) {
+        return false;
+    }
+    if (x->ne[1] != y->ne[1]) {
+        return false;
+    }
+    if (x->ne[2] != y->ne[2]) {
+        return false;
+    }
+    if (x->ne[3] != y->ne[3]) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool hex_supported_src0_type(ggml_type t) {
+    return t == GGML_TYPE_F32;
+}
+
+static bool hex_supported_src1_type(ggml_type t) {
+    return t == GGML_TYPE_F32;
+}
+
+static bool hex_supported_src2_type(ggml_type t) {
+    return t == GGML_TYPE_F32;
+}
+
+static bool hex_supported_src1_type2(ggml_type t) {
+    return t == GGML_TYPE_F16;
+}
+
+static bool hex_supported_src1_type3(ggml_type t) {
+    return t == GGML_TYPE_I32;
+}
+
+static bool hex_supported_dst_type(ggml_type t) {
+    return t == GGML_TYPE_F32;
+}
+
+static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
+    // TODO: support broadcast for ne[2 and 3]
+    if (x->ne[0] != y->ne[0]) {
+        return false;
+    }
+    if (x->ne[2] != y->ne[2]) {
+        return false;
+    }
+    if (x->ne[3] != y->ne[3]) {
+        return false;
+    }
+    return true;
+}
+
+static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    // TODO: add support for non-cont tensors
+    if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
+            if (src0->ne[0] % 32) {
+                return false;
+            }
+
+            if (src0->ne[1] > 16 * 1024) {
+                return false;  // typically the lm-head which would be too large for VTCM
+            }
+
+            // if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false;
+            if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
+                return false;
+            }
+
+            // src0 (weights) must be repacked
+            if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {
+                return false;
+            }
+            break;
+
+        case GGML_TYPE_F16:
+            if (!opt_experimental) {
+                return false;
+            }
+            break;
+
+        default:
+            return false;
+    }
+
+    // src0 & src1 & dst must be mapped to the same session
+    if (src0->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
+        return false;
+    }
+    if (src1->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
+        return false;
+    }
+    if (dst->buffer &&
+        (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * src2 = op->src[2];
+    const struct ggml_tensor * dst  = op;
+
+    if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32 || src2->type != GGML_TYPE_I32) {
+        return false;
+    }
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
+            if ((src0->ne[0] % 32)) {
+                return false;
+            }
+
+            // src0 (weights) must be repacked
+            if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {
+                return false;
+            }
+            break;
+
+        case GGML_TYPE_F16:
+            if (!opt_experimental) {
+                return false;
+            }
+            break;
+
+        default:
+            return false;
+    }
+
+    // TODO: add support for non-cont tensors
+    if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    // src0 (weights) must be repacked and mapped to the same session
+    // src1 & sr2 & dst must be mapped to the same session
+    if (src0->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
+        return false;
+    }
+    if (src1->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
+        return false;
+    }
+    if (src2->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) {
+        return false;
+    }
+    if (dst->buffer &&
+        (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * dst  = op;
+
+    if (!hex_supported_src0_type(src0->type)) {
+        return false;
+    }
+    if (!hex_supported_src1_type(src1->type)) {
+        return false;
+    }
+    if (!hex_supported_dst_type(dst->type)) {
+        return false;
+    }
+    if (!hex_supported_dims2(src0, dst)) {
+        return false;
+    }
+    if (!ggml_can_repeat(src1, src0)) {
+        return false;
+    }
+
+    // TODO: add support for non-contigiuos tensors
+    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    // src0, src1 & dst must be mapped to the same session
+    if (src0->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
+        return false;
+    }
+    if (src1->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
+        return false;
+    }
+    if (dst->buffer &&
+        (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * src2 = op->src[2];
+    const struct ggml_tensor * dst  = op;
+
+    if (!hex_supported_src0_type(src0->type)) {
+        return false;
+    }
+    if (!hex_supported_src1_type(src1->type)) {
+        return false;
+    }
+    if (!hex_supported_dst_type(dst->type)) {
+        return false;
+    }
+    if (!hex_supported_dims2(src0, dst)) {
+        return false;
+    }
+
+    // REVISIT: add support for non-contigiuos tensors
+    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    // src0, src1 & dst must be mapped to the same session
+    if (src0->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
+        return false;
+    }
+    if (src1->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
+        return false;
+    }
+    if (src2->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) {
+        return false;
+    }
+    if (dst->buffer &&
+        (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * dst  = op;
+
+    if (!hex_supported_src0_type(src0->type)) {
+        return false;
+    }
+    if (!hex_supported_dst_type(dst->type)) {
+        return false;
+    }
+    if (!hex_supported_dims2(src0, dst)) {
+        return false;
+    }
+
+    // TODO: add support for non-contigiuos tensors
+    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    // src0 & dst must be mapped to the same session
+    if (src0->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
+        return false;
+    }
+    if (dst->buffer &&
+        (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
+                                               const struct ggml_tensor *          op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * dst  = op;
+
+    if (!hex_supported_src0_type(src0->type)) {
+        return false;
+    }
+    if (!hex_supported_dst_type(dst->type)) {
+        return false;
+    }
+
+    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    if (src1) {
+        if (!hex_supported_src1_type(src1->type)) {
+            return false;
+        }
+        if (!hex_supported_dims2(src0, src1)) {
+            return false;
+        }
+        if (!ggml_is_contiguous(src1)) {
+            return false;
+        }
+    }
+
+    // src0, src1 & dst must be mapped to the same session
+    if (src0->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
+        return false;
+    }
+    if (src1 && src1->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
+        return false;
+    }
+    if (dst->buffer &&
+        (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * src2 = op->src[2];
+    const struct ggml_tensor * dst  = op;
+
+    if (src2) {
+        return false;  // FIXME: add support for sinks
+    }
+
+    if (!hex_supported_src0_type(src0->type)) {
+        return false;
+    }
+    if (!hex_supported_dst_type(dst->type)) {
+        return false;
+    }
+
+    if (src1) {
+        if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
+            return false;
+        }
+        if (src0->ne[0] != src1->ne[0]) {
+            return false;
+        }
+        if (src1->ne[1] < src0->ne[1]) {
+            return false;
+        }
+        if (src0->ne[2] % src1->ne[2] != 0) {
+            return false;
+        }
+        if (src0->ne[3] % src1->ne[3] != 0) {
+            return false;
+        }
+    }
+
+    if (src1) {
+        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+            return false;
+        }
+    } else {
+        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
+            return false;
+        }
+    }
+
+    // src0, src1 & dst must be mapped to the same session
+    if (src0->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
+        return false;
+    }
+    if (src1 && src1->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
+        return false;
+    }
+    if (dst->buffer &&
+        (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const int32_t * op_params = &op->op_params[0];
+
+    int mode = op_params[2];
+
+    if ((mode & GGML_ROPE_TYPE_NEOX) || (mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
+        return false;
+    }
+    if (mode & 1) {
+        return false;
+    }
+
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * src2 = op->src[2];
+    const struct ggml_tensor * dst  = op;
+
+    if (!hex_supported_src0_type(src0->type)) {
+        return false;  // FIXME: add support for GGML_TYPE_F16 for src0
+    }
+    if (!hex_supported_dst_type(dst->type)) {
+        return false;
+    }
+    if (!hex_supported_src1_type3(src1->type)) {
+        return false;
+    }
+    if (src2) {
+        if (!hex_supported_src2_type(src2->type)) {
+            return false;
+        }
+        int n_dims = op_params[1];
+        if (src2->ne[0] < (n_dims / 2)) {
+            return false;
+        }
+    }
+
+    if (src2) {
+        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(src2) ||
+            !ggml_is_contiguous(dst)) {
+            return false;
+        }
+    } else {
+        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+            return false;
+        }
+    }
+
+    // src0, src1, src2 & dst must be mapped to the same session
+    if (src0->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
+        return false;
+    }
+    if (src1->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
+        return false;
+    }
+    if (src2 && src2->buffer &&
+        (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) {
+        return false;
+    }
+    if (dst->buffer &&
+        (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
+        return false;
+    }
+
+    return true;
+}
+
+// Init hexagon tensor from GGML tensor and Hexagon buffer
+static void init_htp_tensor(htp_tensor * h, const ggml_tensor * t) {
+    h->data  = 0;  // updated by the receiver
+    h->type  = t->type;
+    h->ne[0] = t->ne[0];
+    h->ne[1] = t->ne[1];
+    h->ne[2] = t->ne[2];
+    h->ne[3] = t->ne[3];
+    h->nb[0] = t->nb[0];
+    h->nb[1] = t->nb[1];
+    h->nb[2] = t->nb[2];
+    h->nb[3] = t->nb[3];
+}
+
+static void hex_dump_dspbuf(const struct ggml_tensor * t, const dspqueue_buffer * d) {
+    auto buf  = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);
+    auto sess = buf->sess;
+
+    HEX_VERBOSE("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(),
+                t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset,
+                (unsigned int) d->size);
+}
+
+static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * dst  = op;
+
+    auto src0_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
+    auto src1_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
+    auto dst_buf  = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
+
+    uint64_t t1, t2;
+    t1 = ggml_time_us();
+
+    // Construct HTP message
+    htp_general_req req;
+    req.op    = HTP_OP_MUL_MAT;
+    req.flags = flags;
+
+    init_htp_tensor(&req.src0, src0);
+    init_htp_tensor(&req.src1, src1);
+    init_htp_tensor(&req.dst, dst);
+
+    // Use opmask to override flags
+    if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
+    }
+    if (!(opt_opmask & HTP_OPMASK_COMPUTE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_COMPUTE;
+    }
+
+    dspqueue_buffer bufs[3];
+    memset(bufs, 0, sizeof(bufs));
+
+    // First buffer Weights.
+    // The content is static, there is no need to do any cache management
+    bufs[0].fd     = src0_buf->fd;
+    bufs[0].ptr    = src0->data;
+    bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
+    bufs[0].size   = ggml_nbytes(src0);
+    bufs[0].flags  = DSPQUEUE_BUFFER_FLAG_REF;
+
+    // Second buffer Input Activations. This is a buffer that the CPU
+    // writes and the DSP reads, so we'll need to flush CPU caches and
+    // invalidate DSP ones. On platforms with I/O coherency support the
+    // framework will automatically skip cache operations where possible.
+    bufs[1].fd     = src1_buf->fd;
+    bufs[1].ptr    = src1->data;
+    bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
+    bufs[1].size   = ggml_nbytes(src1);
+    bufs[1].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                     DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                     DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP
+
+    // Third buffer Output Activations. We'll handle DSP
+    // cache maintenance in the response message but need to flush
+    // CPU caches to ensure any previously written dirty lines are
+    // written out before writes from the DSP start.
+    bufs[2].fd     = dst_buf->fd;
+    bufs[2].ptr    = dst->data;
+    bufs[2].offset = (uint8_t *) dst->data - dst_buf->base;
+    bufs[2].size   = ggml_nbytes(dst);
+    bufs[2].flags  = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
+
+    // Primary DSP session from the src0 (normally weight) tensor
+    auto sess = src0_buf->sess;
+
+    if (opt_verbose) {
+        char dims[64 * GGML_MAX_SRC];
+        char strides[64 * GGML_MAX_SRC];
+        char types[16 * GGML_MAX_SRC];
+        char buffs[64 * GGML_MAX_SRC];
+        char names[64 * GGML_MAX_SRC];
+
+        hex_format_op_dims(dims, op);
+        hex_format_op_strides(strides, op);
+        hex_format_op_types(types, op);
+        hex_format_op_buffs(buffs, op);
+        hex_format_op_names(names, op);
+
+        HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op),
+                    names, dims, types, strides, buffs, req.flags);
+        if (opt_verbose > 1) {
+            hex_dump_dspbuf(src0, &bufs[0]);
+            hex_dump_dspbuf(src1, &bufs[1]);
+            hex_dump_dspbuf(dst, &bufs[2]);
+        }
+    }
+
+    if ((opt_opmask & HTP_OPMASK_QUEUE)) {
+        // Bump pending flag (cleared in the callback once we get the responce)
+        sess->op_pending++;  // atomic inc
+
+        int err = dspqueue_write(sess->queue,
+                                 0,                       // flags - the framework will autoset this
+                                 3,                       // number of buffers
+                                 bufs,                    // buffer references
+                                 sizeof(req),
+                                 (const uint8_t *) &req,  // Message
+                                 1000000                  // Timeout
+        );
+
+        if (err != 0) {
+            GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
+        }
+    }
+
+    if (opt_opsync) {
+        while (sess->op_pending) {
+            ;
+        }
+    }
+
+    t2 = ggml_time_us();
+
+    HEX_PROFILE(
+        "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) "
+        "call-usec %llu\n",
+        sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1],
+        (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1],
+        (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1],
+        (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts,
+        (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1);
+}
+
+static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flags) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * src2 = op->src[2];
+    const struct ggml_tensor * dst  = op;
+
+    auto src0_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
+    auto src1_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
+    auto src2_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src2->buffer->context);
+    auto dst_buf  = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
+
+    uint64_t t1, t2;
+    t1 = ggml_time_us();
+
+    // Construct HTP message
+    htp_general_req req;
+    req.op    = HTP_OP_MUL_MAT_ID;
+    req.flags = flags;
+
+    init_htp_tensor(&req.src0, src0);
+    init_htp_tensor(&req.src1, src1);
+    init_htp_tensor(&req.src2, src2);
+    init_htp_tensor(&req.dst, dst);
+
+    // Use opmask to override flags
+    if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
+    }
+    if (!(opt_opmask & HTP_OPMASK_COMPUTE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_COMPUTE;
+    }
+
+    dspqueue_buffer bufs[4];
+    memset(bufs, 0, sizeof(bufs));
+
+    // First buffer Weights.
+    // The content is static, there is no need to do any cache management
+    bufs[0].fd     = src0_buf->fd;
+    bufs[0].ptr    = src0->data;
+    bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
+    bufs[0].size   = ggml_nbytes(src0);
+    bufs[0].flags  = DSPQUEUE_BUFFER_FLAG_REF;
+
+    // Second buffer Input Activations. This is a buffer that the CPU
+    // writes and the DSP reads, so we'll need to flush CPU caches and
+    // invalidate DSP ones. On platforms with I/O coherency support the
+    // framework will automatically skip cache operations where possible.
+    bufs[1].fd     = src1_buf->fd;
+    bufs[1].ptr    = src1->data;
+    bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
+    bufs[1].size   = ggml_nbytes(src1);
+    bufs[1].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                     DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                     DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP
+
+    // Third buffer expert IDs. This is a buffer that the CPU
+    // writes and the DSP reads, so we'll need to flush CPU caches and
+    // invalidate DSP ones. On platforms with I/O coherency support the
+    // framework will automatically skip cache operations where possible.
+    bufs[2].fd     = src2_buf->fd;
+    bufs[2].ptr    = src2->data;
+    bufs[2].offset = (uint8_t *) src2->data - src2_buf->base;
+    bufs[2].size   = ggml_nbytes(src2);
+    bufs[2].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                     DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                     DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP
+
+    // Forth buffer Output Activations. We'll handle DSP
+    // cache maintenance in the response message but need to flush
+    // CPU caches to ensure any previously written dirty lines are
+    // written out before writes from the DSP start.
+    bufs[3].fd     = dst_buf->fd;
+    bufs[3].ptr    = dst->data;
+    bufs[3].offset = (uint8_t *) dst->data - dst_buf->base;
+    bufs[3].size   = ggml_nbytes(dst);
+    bufs[3].flags  = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
+
+    // Primary DSP session from the src0 (normally weight) tensor
+    auto sess = src0_buf->sess;
+
+    if (opt_verbose) {
+        char dims[64 * GGML_MAX_SRC];
+        char strides[64 * GGML_MAX_SRC];
+        char types[16 * GGML_MAX_SRC];
+        char buffs[64 * GGML_MAX_SRC];
+        char names[64 * GGML_MAX_SRC];
+
+        hex_format_op_dims(dims, op);
+        hex_format_op_types(types, op);
+        hex_format_op_buffs(buffs, op);
+        hex_format_op_names(names, op);
+
+        HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op),
+                    names, dims, types, strides, buffs, req.flags);
+
+        if (opt_verbose > 1) {
+            hex_dump_dspbuf(src0, &bufs[0]);
+            hex_dump_dspbuf(src1, &bufs[1]);
+            hex_dump_dspbuf(src2, &bufs[2]);
+            hex_dump_dspbuf(dst, &bufs[3]);
+        }
+    }
+
+    if ((opt_opmask & HTP_OPMASK_QUEUE)) {
+        // Bump pending flag (cleared in the callback once we get the responce)
+        sess->op_pending++;  // atomic inc
+
+        int err = dspqueue_write(sess->queue,
+                                 0,                       // flags - the framework will autoset this
+                                 4,                       // number of buffers
+                                 bufs,                    // buffer references
+                                 sizeof(req),
+                                 (const uint8_t *) &req,  // Message
+                                 1000000                  // Timeout
+        );
+
+        if (err != 0) {
+            GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
+        }
+    }
+
+    if (opt_opsync) {
+        while (sess->op_pending) {
+            ;
+        }
+    }
+
+    t2 = ggml_time_us();
+
+    HEX_PROFILE(
+        "ggml-hex: %s matmul-id %s %u:%u:%u:%u x %s %u:%u:%u:%u (%s %u:%u:%u:%u) -> %s %u:%u:%u:%u : op-usec %u "
+        "op-cycles %u op-pkts %u (%f) call-usec %llu\n",
+        sess->name.c_str(), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], (uint32_t) src0->ne[2],
+        (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], (uint32_t) src1->ne[2],
+        (uint32_t) src1->ne[3], src2->name, (uint32_t) src2->ne[0], (uint32_t) src2->ne[1], (uint32_t) src2->ne[2],
+        (uint32_t) src2->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],
+        (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts,
+        (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1);
+}
+
+static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
+    const struct ggml_tensor * node = op;
+    const struct ggml_tensor * src0 = node->src[0];
+    const struct ggml_tensor * src1 = node->src[1];
+    const struct ggml_tensor * dst  = node;
+
+    auto src0_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
+    auto src1_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
+    auto dst_buf  = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
+
+    uint64_t t1 = 0;
+    uint64_t t2 = 0;
+
+    t1 = ggml_time_us();
+
+    // Construct HTP message
+    htp_general_req req;
+    req.flags = flags;
+
+    // Use opmask to override flags
+    if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
+    }
+    if (!(opt_opmask & HTP_OPMASK_COMPUTE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_COMPUTE;
+    }
+
+    switch (node->op) {
+        case GGML_OP_MUL:
+            req.op = HTP_OP_MUL;
+            break;
+        case GGML_OP_ADD:
+            req.op = HTP_OP_ADD;
+            break;
+        case GGML_OP_SUB:
+            req.op = HTP_OP_SUB;
+            break;
+        default:
+            GGML_ABORT("ggml-hex: binary : unsupported op:%d\n", node->op);
+    }
+
+    init_htp_tensor(&req.src0, src0);
+    init_htp_tensor(&req.src1, src1);
+    init_htp_tensor(&req.dst, dst);
+
+    dspqueue_buffer bufs[3];
+    memset(bufs, 0, sizeof(bufs));
+
+    // First buffer = First Operand of Binary op
+    // This is a buffer that the CPU writes and the DSP reads, so we'll
+    // need to flush CPU caches and invalidate DSP ones. On platforms
+    // with I/O coherency support the framework will automatically skip
+    // cache operations where possible.
+    bufs[0].fd     = src0_buf->fd;
+    bufs[0].ptr    = src0->data;
+    bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
+    bufs[0].size   = ggml_nbytes(src0);
+    bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                     DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                     DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP;
+
+    // Second buffer = Second Operand of Binary op
+    // This is a buffer that the CPU writes and the DSP reads, so we'll
+    // need to flush CPU caches and invalidate DSP ones. On platforms
+    // with I/O coherency support the framework will automatically skip
+    // cache operations where possible.
+    bufs[1].fd     = src1_buf->fd;
+    bufs[1].ptr    = src1->data;
+    bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
+    bufs[1].size   = ggml_nbytes(src1);
+    bufs[1].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                     DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                     DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP
+
+    // Third buffer = Output Activations. We'll handle DSP
+    // cache maintenance in the response message but need to flush
+    // CPU caches to ensure any previously written dirty lines are
+    // written out before writes from the DSP start.
+    bufs[2].fd     = dst_buf->fd;
+    bufs[2].ptr    = dst->data;
+    bufs[2].offset = (uint8_t *) dst->data - dst_buf->base;
+    bufs[2].size   = ggml_nbytes(dst);
+    bufs[2].flags  = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
+
+    // Primary DSP session from the src0 tensor
+    ggml_hexagon_session * sess = src0_buf->sess;
+
+    if (opt_verbose) {
+        char dims[64 * GGML_MAX_SRC];
+        char strides[16 * GGML_MAX_SRC];
+        char types[16 * GGML_MAX_SRC];
+        char buffs[64 * GGML_MAX_SRC];
+        char names[64 * GGML_MAX_SRC];
+
+        hex_format_op_dims(dims, op);
+        hex_format_op_strides(strides, op);
+        hex_format_op_types(types, op);
+        hex_format_op_buffs(buffs, op);
+        hex_format_op_names(names, op);
+
+        HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(),
+                    ggml_op_name(node->op), names, dims, types, strides, buffs, req.flags);
+        if (opt_verbose > 1) {
+            hex_dump_dspbuf(src0, &bufs[0]);
+            hex_dump_dspbuf(src1, &bufs[1]);
+            hex_dump_dspbuf(dst, &bufs[2]);
+        }
+    }
+
+    if ((opt_opmask & HTP_OPMASK_QUEUE)) {
+        // Bump pending flag (cleared in the callback once we get the responce)
+        sess->op_pending++;  // atomic inc
+
+        int err = dspqueue_write(sess->queue,
+                                 0,                       // flags - the framework will autoset this
+                                 3,                       // number of buffers
+                                 bufs,                    // buffer references
+                                 sizeof(req),
+                                 (const uint8_t *) &req,  // Message
+                                 1000000);                // Timeout
+
+        if (0 != err) {
+            GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
+        }
+    }
+
+    if (opt_opsync) {
+        while (sess->op_pending) {
+            ;
+        }
+    }
+
+    t2 = ggml_time_us();
+
+    HEX_PROFILE(
+        "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) "
+        "call-usec %llu\n",
+        sess->name.c_str(), ggml_op_name(node->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1],
+        (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1],
+        (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1],
+        (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts,
+        (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1);
+}
+
+static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
+    const struct ggml_tensor * node = op;
+    const struct ggml_tensor * src0 = node->src[0];
+    const struct ggml_tensor * src1 = node->src[1];
+    const struct ggml_tensor * src2 = node->src[2];
+    const struct ggml_tensor * dst  = node;
+
+    auto src0_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
+    auto src1_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
+    auto src2_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src2->buffer->context);
+    auto dst_buf  = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
+
+    uint64_t t1 = 0;
+    uint64_t t2 = 0;
+
+    t1 = ggml_time_us();
+
+    // Construct HTP message
+    htp_general_req req;
+    req.flags = flags;
+
+    // Use opmask to override flags
+    if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
+    }
+    if (!(opt_opmask & HTP_OPMASK_COMPUTE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_COMPUTE;
+    }
+
+    switch (node->op) {
+        case GGML_OP_ADD_ID:
+            req.op = HTP_OP_ADD_ID;
+            break;
+        default:
+            GGML_ABORT("ggml-hex: unsupported op:%d\n", node->op);
+    }
+
+    init_htp_tensor(&req.src0, src0);
+    init_htp_tensor(&req.src1, src1);
+    init_htp_tensor(&req.src2, src2);
+    init_htp_tensor(&req.dst, dst);
+
+    dspqueue_buffer bufs[4];
+    memset(bufs, 0, sizeof(bufs));
+
+    // First buffer = input activations
+    bufs[0].fd     = src0_buf->fd;
+    bufs[0].ptr    = src0->data;
+    bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
+    bufs[0].size   = ggml_nbytes(src0);
+    bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                     DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                     DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP;
+
+    // Second buffer = experts bias
+    bufs[1].fd     = src1_buf->fd;
+    bufs[1].ptr    = src1->data;
+    bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
+    bufs[1].size   = ggml_nbytes(src1);
+    bufs[1].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                     DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                     DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP
+
+    // Third buffer = activated experts
+    bufs[2].fd     = src2_buf->fd;
+    bufs[2].ptr    = src2->data;
+    bufs[2].offset = (uint8_t *) src2->data - src2_buf->base;
+    bufs[2].size   = ggml_nbytes(src2);
+    bufs[2].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                     DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                     DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP
+
+    // Forth buffer = output activations
+    bufs[3].fd     = dst_buf->fd;
+    bufs[3].ptr    = dst->data;
+    bufs[3].offset = (uint8_t *) dst->data - dst_buf->base;
+    bufs[3].size   = ggml_nbytes(dst);
+    bufs[3].flags  = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
+
+    // Primary DSP session from the src0 tensor
+    ggml_hexagon_session * sess = src0_buf->sess;
+
+    if (opt_verbose) {
+        char dims[64 * GGML_MAX_SRC];
+        char strides[16 * GGML_MAX_SRC];
+        char types[16 * GGML_MAX_SRC];
+        char buffs[64 * GGML_MAX_SRC];
+        char names[64 * GGML_MAX_SRC];
+
+        hex_format_op_dims(dims, op);
+        hex_format_op_strides(strides, op);
+        hex_format_op_types(types, op);
+        hex_format_op_buffs(buffs, op);
+        hex_format_op_names(names, op);
+
+        HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(),
+                    ggml_op_name(node->op), names, dims, types, strides, buffs, req.flags);
+
+        if (opt_verbose > 1) {
+            hex_dump_dspbuf(src0, &bufs[0]);
+            hex_dump_dspbuf(src1, &bufs[1]);
+            hex_dump_dspbuf(src2, &bufs[2]);
+            hex_dump_dspbuf(dst, &bufs[3]);
+        }
+    }
+
+    if ((opt_opmask & HTP_OPMASK_QUEUE)) {
+        // Bump pending flag (cleared in the callback once we get the responce)
+        sess->op_pending++;  // atomic inc
+
+        int err = dspqueue_write(sess->queue,
+                                 0,                       // flags - the framework will autoset this
+                                 4,                       // number of buffers
+                                 bufs,                    // buffer references
+                                 sizeof(req),
+                                 (const uint8_t *) &req,  // Message
+                                 1000000);                // Timeout
+
+        if (0 != err) {
+            GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
+        }
+    }
+
+    if (opt_opsync) {
+        while (sess->op_pending) {
+            ;
+        }
+    }
+
+    t2 = ggml_time_us();
+
+    HEX_PROFILE(
+        "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) "
+        "call-usec %llu\n",
+        sess->name.c_str(), ggml_op_name(node->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1],
+        (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1],
+        (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1],
+        (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts,
+        (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1);
+}
+
+static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * dst  = op;
+
+    uint64_t t1 = 0;
+    uint64_t t2 = 0;
+
+    t1 = ggml_time_us();
+
+    // Construct HTP message
+    htp_general_req req;
+
+    memset(&req, 0, sizeof(htp_general_req));
+    memcpy(&req.op_params, &op->op_params, sizeof(op->op_params));
+    req.flags = flags;
+
+    bool supported = false;
+
+    switch (op->op) {
+        case GGML_OP_RMS_NORM:
+            req.op    = HTP_OP_RMS_NORM;
+            supported = true;
+            break;
+
+        case GGML_OP_UNARY:
+            if (ggml_get_unary_op(dst) == GGML_UNARY_OP_SILU) {
+                req.op    = HTP_OP_UNARY_SILU;
+                supported = true;
+            }
+            break;
+
+        case GGML_OP_GLU:
+            if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU) {
+                req.op    = HTP_OP_GLU_SWIGLU;
+                supported = true;
+            } else if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) {
+                req.op    = HTP_OP_GLU_SWIGLU_OAI;
+                supported = true;
+            }
+            break;
+
+        case GGML_OP_SOFT_MAX:
+            req.op    = HTP_OP_SOFTMAX;
+            supported = true;
+
+        default:
+            break;
+    }
+
+    if (!supported) {
+        GGML_ABORT("ggml-hex: unary : unsupported op:%d\n", op->op);
+    }
+
+    init_htp_tensor(&req.dst, dst);
+    init_htp_tensor(&req.src0, src0);
+    if (src1) {
+        init_htp_tensor(&req.src1, src1);
+    }
+
+    // Use opmask to override flags
+    if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
+    }
+    if (!(opt_opmask & HTP_OPMASK_COMPUTE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_COMPUTE;
+    }
+
+    dspqueue_buffer bufs[3];
+    int             n_bufs = 0;
+
+    memset(bufs, 0, sizeof(bufs));
+
+    // First buffer = Only Operand of Unary op
+    // This is a buffer that the CPU writes and the DSP reads, so we'll
+    // need to flush CPU caches and invalidate DSP ones. On platforms
+    // with I/O coherency support the framework will automatically skip
+    // cache operations where possible.
+    auto src0_buf       = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
+    bufs[n_bufs].fd     = src0_buf->fd;
+    bufs[n_bufs].ptr    = src0->data;
+    bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base;
+    bufs[n_bufs].size   = ggml_nbytes(src0);
+    bufs[n_bufs].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                          DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP;
+    ++n_bufs;
+
+    if (src1) {
+        // Second buffer = Second Operand of Binary op
+        // This is a buffer that the CPU writes and the DSP reads, so we'll
+        // need to flush CPU caches and invalidate DSP ones. On platforms
+        // with I/O coherency support the framework will automatically skip
+        // cache operations where possible.
+        auto src1_buf       = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
+        bufs[n_bufs].fd     = src1_buf->fd;
+        bufs[n_bufs].ptr    = src1->data;
+        bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base;
+        bufs[n_bufs].size   = ggml_nbytes(src1);
+        bufs[n_bufs].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                              DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                              DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP
+        ++n_bufs;
+    }
+
+    // Second or third buffer = Output Activations. We'll handle DSP
+    // Second buffer = Output Activations. We'll handle DSP
+    // cache maintenance in the response message but need to flush
+    // CPU caches to ensure any previously written dirty lines are
+    // written out before writes from the DSP start.
+    auto dst_buf        = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
+    bufs[n_bufs].fd     = dst_buf->fd;
+    bufs[n_bufs].ptr    = dst->data;
+    bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base;
+    bufs[n_bufs].size   = ggml_nbytes(dst);
+    bufs[n_bufs].flags  = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
+    ++n_bufs;
+
+    // Primary DSP session from the src0 tensor
+    ggml_hexagon_session * sess = src0_buf->sess;
+
+    if (opt_verbose) {
+        char dims[64 * GGML_MAX_SRC];
+        char strides[64 * GGML_MAX_SRC];
+        char types[16 * GGML_MAX_SRC];
+        char buffs[64 * GGML_MAX_SRC];
+        char names[64 * GGML_MAX_SRC];
+
+        hex_format_op_dims(dims, op);
+        hex_format_op_strides(strides, op);
+        hex_format_op_types(types, op);
+        hex_format_op_buffs(buffs, op);
+        hex_format_op_names(names, op);
+
+        HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op),
+                    names, dims, types, strides, buffs, req.flags);
+        if (opt_verbose > 1) {
+            hex_dump_dspbuf(src0, &bufs[0]);
+            if (src1) {
+                hex_dump_dspbuf(src1, &bufs[1]);
+                hex_dump_dspbuf(dst, &bufs[2]);
+            } else {
+                hex_dump_dspbuf(dst, &bufs[1]);
+            }
+        }
+    }
+
+    if ((opt_opmask & HTP_OPMASK_QUEUE)) {
+        // Bump pending flag (cleared in the callback once we get the responce)
+        sess->op_pending++;  // atomic inc
+
+        int err = dspqueue_write(sess->queue,
+                                 0,                       // flags - the framework will autoset this
+                                 n_bufs,                  // number of buffers
+                                 bufs,                    // buffer references
+                                 sizeof(req),
+                                 (const uint8_t *) &req,  // Message
+                                 1000000);                // Timeout
+
+        if (0 != err) {
+            GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
+        }
+    }
+
+    if (opt_opsync) {
+        while (sess->op_pending) {
+            ;
+        }
+    }
+
+    t2 = ggml_time_us();
+
+    if (src1) {
+        HEX_PROFILE(
+            "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u "
+            "(%f) call-usec %llu\n",
+            sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1],
+            (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1],
+            (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1],
+            (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts,
+            (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1);
+    } else {
+        HEX_PROFILE(
+            "ggml-hex: %s %s %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) call-usec "
+            "%llu\n",
+            sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1],
+            (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1],
+            (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts,
+            (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1);
+    }
+}
+
+static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * src2 = op->src[2];
+    const struct ggml_tensor * dst  = op;
+
+    uint64_t t1 = 0;
+    uint64_t t2 = 0;
+
+    t1 = ggml_time_us();
+
+    // Construct HTP message
+    htp_general_req req;
+
+    memset(&req, 0, sizeof(htp_general_req));
+    memcpy(&req.op_params, &op->op_params, sizeof(op->op_params));
+    req.flags = flags;
+    req.op    = HTP_OP_ROPE;
+
+    init_htp_tensor(&req.dst, dst);
+    init_htp_tensor(&req.src0, src0);
+    init_htp_tensor(&req.src1, src1);
+    if (src2) {
+        init_htp_tensor(&req.src2, src2);
+    }
+
+    // Use opmask to override flags
+    if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
+    }
+    if (!(opt_opmask & HTP_OPMASK_COMPUTE)) {
+        req.flags |= HTP_OPFLAGS_SKIP_COMPUTE;
+    }
+
+    dspqueue_buffer bufs[4];
+    int             n_bufs = 0;
+
+    memset(bufs, 0, sizeof(bufs));
+
+    // First buffer
+    // This is a buffer that the CPU writes and the DSP reads, so we'll
+    // need to flush CPU caches and invalidate DSP ones. On platforms
+    // with I/O coherency support the framework will automatically skip
+    // cache operations where possible.
+    auto src0_buf       = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
+    bufs[n_bufs].fd     = src0_buf->fd;
+    bufs[n_bufs].ptr    = src0->data;
+    bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base;
+    bufs[n_bufs].size   = ggml_nbytes(src0);
+    bufs[n_bufs].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                          DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP;
+    ++n_bufs;
+
+    // Second buffer
+    // This is a buffer that the CPU writes and the DSP reads, so we'll
+    // need to flush CPU caches and invalidate DSP ones. On platforms
+    // with I/O coherency support the framework will automatically skip
+    // cache operations where possible.
+    auto src1_buf       = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
+    bufs[n_bufs].fd     = src1_buf->fd;
+    bufs[n_bufs].ptr    = src1->data;
+    bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base;
+    bufs[n_bufs].size   = ggml_nbytes(src1);
+    bufs[n_bufs].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                          DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP
+    ++n_bufs;
+
+    if (src2) {
+        // Third buffer
+        // This is a buffer that the CPU writes and the DSP reads, so we'll
+        // need to flush CPU caches and invalidate DSP ones. On platforms
+        // with I/O coherency support the framework will automatically skip
+        // cache operations where possible.
+        auto src2_buf       = static_cast<ggml_backend_hexagon_buffer_context *>(src2->buffer->context);
+        bufs[n_bufs].fd     = src2_buf->fd;
+        bufs[n_bufs].ptr    = src2->data;
+        bufs[n_bufs].offset = (uint8_t *) src2->data - src2_buf->base;
+        bufs[n_bufs].size   = ggml_nbytes(src2);
+        bufs[n_bufs].flags  = (DSPQUEUE_BUFFER_FLAG_REF |                   // Take a reference
+                              DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush CPU
+                              DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate DSP
+        ++n_bufs;
+    }
+
+    // Final buffer = Output Activations. We'll handle DSP
+    // Second buffer = Output Activations. We'll handle DSP
+    // cache maintenance in the response message but need to flush
+    // CPU caches to ensure any previously written dirty lines are
+    // written out before writes from the DSP start.
+    auto dst_buf        = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
+    bufs[n_bufs].fd     = dst_buf->fd;
+    bufs[n_bufs].ptr    = dst->data;
+    bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base;
+    bufs[n_bufs].size   = ggml_nbytes(dst);
+    bufs[n_bufs].flags  = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
+    ++n_bufs;
+
+    // Primary DSP session from the src0 tensor
+    ggml_hexagon_session * sess = src0_buf->sess;
+
+    if (opt_verbose) {
+        char dims[64 * GGML_MAX_SRC];
+        char strides[64 * GGML_MAX_SRC];
+        char types[16 * GGML_MAX_SRC];
+        char buffs[64 * GGML_MAX_SRC];
+        char names[64 * GGML_MAX_SRC];
+
+        hex_format_op_dims(dims, op);
+        hex_format_op_strides(strides, op);
+        hex_format_op_types(types, op);
+        hex_format_op_buffs(buffs, op);
+        hex_format_op_names(names, op);
+
+        HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op),
+                    names, dims, types, strides, buffs, req.flags);
+        if (opt_verbose > 1) {
+            hex_dump_dspbuf(src0, &bufs[0]);
+            if (src1) {
+                hex_dump_dspbuf(src1, &bufs[1]);
+                hex_dump_dspbuf(dst, &bufs[2]);
+            } else {
+                hex_dump_dspbuf(dst, &bufs[1]);
+            }
+        }
+    }
+
+    if ((opt_opmask & HTP_OPMASK_QUEUE)) {
+        // Bump pending flag (cleared in the callback once we get the responce)
+        sess->op_pending++;  // atomic inc
+
+        int err = dspqueue_write(sess->queue,
+                                 0,                       // flags - the framework will autoset this
+                                 n_bufs,                  // number of buffers
+                                 bufs,                    // buffer references
+                                 sizeof(req),
+                                 (const uint8_t *) &req,  // Message
+                                 1000000);                // Timeout
+
+        if (0 != err) {
+            GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
+        }
+    }
+
+    if (opt_opsync) {
+        while (sess->op_pending) {
+            ;
+        }
+    }
+
+    t2 = ggml_time_us();
+
+    if (src2) {
+        HEX_PROFILE(
+            "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles "
+            "%u op-pkts %u (%f) call-usec %llu\n",
+            sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1],
+            (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1],
+            (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], src2->name, (uint32_t) src2->ne[0], (uint32_t) src2->ne[1],
+            (uint32_t) src2->ne[2], (uint32_t) src2->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1],
+            (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts,
+            (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1);
+    } else {
+        HEX_PROFILE(
+            "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u "
+            "(%f) call-usec %llu\n",
+            sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1],
+            (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1],
+            (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1],
+            (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts,
+            (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1);
+    }
+}
+
+static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
+    auto sess = static_cast<ggml_hexagon_session *>(backend->context);
+    return sess->name.c_str();
+}
+
+static void ggml_backend_hexagon_free(ggml_backend_t backend) {
+    // we just need to delete the backend here
+    // the sessions are allocated & freed as part of the registry
+    delete backend;
+}
+
+static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
+    return (op0 && op0->src[1] == op1->src[1]);
+}
+
+// scan the graph and figure out last compute op index
+static inline int last_compute_op(ggml_cgraph * graph) {
+    int last;
+    for (int i = 0; i < graph->n_nodes; ++i) {
+        ggml_tensor * node = graph->nodes[i];
+
+        switch (node->op) {
+            case GGML_OP_MUL_MAT:
+            case GGML_OP_MUL_MAT_ID:
+            case GGML_OP_MUL:
+            case GGML_OP_ADD:
+            case GGML_OP_SUB:
+            case GGML_OP_RMS_NORM:
+            case GGML_OP_GLU:
+            case GGML_OP_ADD_ID:
+                last = i;
+                break;
+
+            default:
+                break;
+        }
+    }
+
+    return last;
+}
+
+static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
+    auto sess = static_cast<ggml_hexagon_session *>(backend->context);
+
+    HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes);
+
+    const int last = last_compute_op(graph);
+
+    const struct ggml_tensor * prev_quant_op = nullptr;  // prev executed op with quantizer
+
+    for (int i = 0; i < graph->n_nodes; ++i) {
+        ggml_tensor * node = graph->nodes[i];
+
+        uint32_t flags = 0;
+
+        // skip quantizer if src1 is reused
+        if (op_reuse_src1(node, prev_quant_op)) {
+            flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
+        }
+
+        // ask for early notification for the last Op
+        if (i == last) {
+            flags |= HTP_OPFLAGS_EARLY_WAKEUP;
+        }
+
+        switch (node->op) {
+            case GGML_OP_MUL_MAT:
+                ggml_hexagon_mul_mat(node, flags);
+                prev_quant_op = node;
+                break;
+            case GGML_OP_MUL_MAT_ID:
+                ggml_hexagon_mul_mat_id(node, flags);
+                prev_quant_op = node;
+                break;
+            case GGML_OP_MUL:
+            case GGML_OP_ADD:
+            case GGML_OP_SUB:
+                ggml_hexagon_binary(node, flags);
+                break;
+            case GGML_OP_ADD_ID:
+                ggml_hexagon_add_id(node, flags);
+                break;
+            case GGML_OP_RMS_NORM:
+                ggml_hexagon_unary(node, flags);
+                break;
+            case GGML_OP_UNARY:
+                if (ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) {
+                    ggml_hexagon_unary(node, flags);
+                }
+                break;
+            case GGML_OP_GLU:
+                if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
+                    (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
+                    ggml_hexagon_unary(node, flags);
+                }
+                break;
+            case GGML_OP_SOFT_MAX:
+                ggml_hexagon_unary(node, flags);
+                break;
+
+            case GGML_OP_ROPE:
+                ggml_hexagon_rope(node, flags);
+                break;
+
+            // non-compute ops
+            case GGML_OP_NONE:
+            case GGML_OP_RESHAPE:
+            case GGML_OP_VIEW:
+            case GGML_OP_PERMUTE:
+            case GGML_OP_TRANSPOSE:
+                break;
+
+            default:
+                GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
+        }
+    }
+
+    // Wait until all pending ops complete
+    while (sess->op_pending) {
+        ;
+    }
+
+    return GGML_STATUS_SUCCESS;
+}
+
+static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) {
+    auto sess = static_cast<ggml_hexagon_session *>(backend->context);
+
+    HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str());
+
+    // Wait until all pending ops complete
+    while (sess->op_pending) {
+        ;
+    }
+}
+
+struct node_info {
+    ggml_tensor * node;
+
+    std::vector<ggml_tensor *> fused;
+
+    ggml_op op() const {
+        return node->op;
+    }
+
+    const ggml_tensor * dst() const {
+        return fused.empty() ? node : fused.back();
+    }
+
+    const ggml_tensor * src0() const {
+        return node->src[0];
+    }
+
+    const ggml_tensor * src1() const {
+        return node->src[1];
+    }
+
+    bool is_empty() const {
+        return ggml_op_is_empty(node->op);
+    }
+
+    void add_fused(ggml_tensor * t) {
+        fused.push_back(t);
+    }
+
+    bool stackable() const {
+        switch (this->op()) {
+            case GGML_OP_MUL_MAT:
+            case GGML_OP_MUL_MAT_ID:
+                return ggml_is_quantized(this->src0()->type);
+            default:
+                return false;
+        }
+    }
+
+    bool same_input(const node_info& n) const {
+        return n.src1() == this->src1();
+    }
+};
+
+static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<node_info> & nodes) {
+    const int n = nodes.size();
+
+    std::vector<int> res;
+    res.reserve(n);
+
+    std::vector<bool> used(n, false);
+
+    // The main goal here is to stack the MUL_MAT ops with the same src1 input.
+    // This allows use to reuse dynamically quantized src1 in VTCM.
+
+    // TODO: the current version might do incorrect reodering in cases where quantized src0
+    //       input is an output of another Op.
+
+    for (int i0 = 0; i0 < n; i0++) {
+        if (used[i0]) {
+            continue;
+        }
+
+        res.push_back(i0);
+
+        const auto & node0 = nodes[i0];
+
+        if (!node0.stackable()) {
+            continue;
+        }
+
+        // that many nodes forward to search for stackable nodes that can reuse VTCM
+        constexpr int N_FORWARD = 8;
+
+        for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
+            if (used[i1]) {
+                continue;
+            }
+
+            const auto & node1 = nodes[i1];
+
+            if (node1.stackable() && node1.same_input(node0)) {
+                res.push_back(i1);
+                used[i1] = true;
+            }
+        }
+    }
+
+    return res;
+}
+
+static void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgraph * gf) {
+    const int n = gf->n_nodes;
+
+    constexpr int MAX_FUSE = 16;
+
+    enum ggml_op ops[MAX_FUSE];
+
+    std::vector<node_info> nodes;
+    nodes.reserve(gf->n_nodes);
+
+    // fuse nodes:
+    // we don't want to make reorders that break fusing, so we first pack all fusable tensors
+    //   and perform the reorder over the fused nodes. after the reorder is done, we unfuse
+    for (int i = 0; i < n; i++) {
+        node_info node = {
+            /*.node =*/ gf->nodes[i],
+            /*.fused =*/ {},
+        };
+
+        // fuse only ops that start with these operations
+        // can be expanded when needed
+        if (node.op() == GGML_OP_ADD ||
+            node.op() == GGML_OP_NORM ||
+            node.op() == GGML_OP_RMS_NORM) {
+            ops[0] = node.op();
+
+            int f = i + 1;
+            while (f < n && f < i + MAX_FUSE) {
+                // conservatively allow fusing only these ops
+                // can be expanded when needed
+                if (gf->nodes[f]->op != GGML_OP_ADD &&
+                    gf->nodes[f]->op != GGML_OP_MUL &&
+                    gf->nodes[f]->op != GGML_OP_NORM &&
+                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {
+                    break;
+                }
+                ops[f - i] = gf->nodes[f]->op;
+                f++;
+            }
+
+            f -= i;
+            for (; f > 1; f--) {
+                if (ggml_can_fuse(gf, i, ops, f)) {
+                    break;
+                }
+            }
+
+            // add the fused tensors into the node info so we can unfuse them later
+            for (int k = 1; k < f; k++) {
+                ++i;
+
+                // the .dst() becomes the last fused tensor
+                node.add_fused(gf->nodes[i]);
+            }
+        }
+
+        nodes.push_back(std::move(node));
+    }
+
+    const auto order = ggml_hexagon_graph_optimize_reorder(nodes);
+
+    // unfuse
+    {
+        int j = 0;
+        for (const auto i : order) {
+            const auto & node = nodes[i];
+
+            gf->nodes[j++] = node.node;
+
+            for (auto * fused : node.fused) {
+                gf->nodes[j++] = fused;
+            }
+        }
+    }
+}
+
+static struct ggml_backend_i hexagon_backend_i = {
+    /* .get_name                = */ ggml_backend_hexagon_name,
+    /* .free                    = */ ggml_backend_hexagon_free,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ ggml_backend_hexagon_synchronize,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_hexagon_graph_compute,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .graph_optimize          = */ ggml_backend_hexagon_graph_optimize,
+};
+
+static ggml_guid_t ggml_backend_hexagon_guid() {
+    static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49,
+                              0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11 };
+    return &guid;
+}
+
+bool ggml_backend_is_hexagon(ggml_backend_t backend) {
+    return backend && backend->iface.get_name == ggml_backend_hexagon_name;
+}
+
+// device interface
+
+static ggml_backend_t ggml_backend_hexagon_device_init(ggml_backend_dev_t dev, const char * params) {
+    auto sess = static_cast<ggml_hexagon_session *>(dev->context);
+
+    return new ggml_backend{
+        /* .guid      = */ ggml_backend_hexagon_guid(),
+        /* .interface = */ hexagon_backend_i,
+        /* .device    = */ dev,
+        /* .context   = */ sess,
+    };
+
+    GGML_UNUSED(params);
+}
+
+static const char * ggml_backend_hexagon_device_get_name(ggml_backend_dev_t dev) {
+    auto sess = static_cast<ggml_hexagon_session *>(dev->context);
+    return sess->name.c_str();
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_hexagon_device_get_description(ggml_backend_dev_t dev) {
+    return "Hexagon";
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_hexagon_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // ~2GB per session for now
+    *free  = 2ULL * 1024 * 1024 * 1024;
+    *total = *free;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_hexagon_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_GPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_hexagon_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_hexagon_device_get_name(dev);
+    props->description = ggml_backend_hexagon_device_get_description(dev);
+    props->type        = ggml_backend_hexagon_device_get_type(dev);
+    ggml_backend_hexagon_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                 = */ true,
+        /* .host_buffer           = */ (bool) opt_hostbuf,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_buffer_type(ggml_backend_dev_t dev) {
+    auto sess = static_cast<ggml_hexagon_session *>(dev->context);
+    return &sess->buffer_type;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_repack_buffer_type(ggml_backend_dev_t dev) {
+    auto sess = static_cast<ggml_hexagon_session *>(dev->context);
+    return &sess->repack_buffer_type;
+}
+
+static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    auto sess = static_cast<ggml_hexagon_session *>(dev->context);
+
+    bool supp = false;
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            supp = true;
+            break;
+
+        case GGML_OP_MUL_MAT:
+            supp = ggml_hexagon_supported_mul_mat(sess, op);
+            break;
+
+        case GGML_OP_MUL_MAT_ID:
+            supp = ggml_hexagon_supported_mul_mat_id(sess, op);
+            break;
+
+        case GGML_OP_MUL:
+        case GGML_OP_ADD:
+        case GGML_OP_SUB:
+            supp = ggml_hexagon_supported_binary(sess, op);
+            break;
+
+        case GGML_OP_ADD_ID:
+            supp = ggml_hexagon_supported_add_id(sess, op);
+            break;
+
+        case GGML_OP_RMS_NORM:
+            supp = ggml_hexagon_supported_unary(sess, op);
+            break;
+
+        case GGML_OP_SOFT_MAX:
+            supp = ggml_hexagon_supported_softmax(sess, op);
+            break;
+
+        case GGML_OP_UNARY:
+            if (ggml_get_unary_op(op) == GGML_UNARY_OP_SILU) {
+                supp = ggml_hexagon_supported_activations(sess, op);
+            }
+            break;
+
+        case GGML_OP_GLU:
+            if ((ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU) /* || (ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU_OAI) */) {
+                supp = ggml_hexagon_supported_activations(sess, op);
+            }
+            break;
+
+        case GGML_OP_ROPE:
+            supp = ggml_hexagon_supported_rope(sess, op);
+            break;
+
+        default:
+            break;
+    }
+
+    if (opt_verbose) {
+        char dims[64 * GGML_MAX_SRC];
+        char strides[64 * GGML_MAX_SRC];
+        char types[16 * GGML_MAX_SRC];
+        char buffs[64 * GGML_MAX_SRC];
+        char names[64 * GGML_MAX_SRC];
+
+        hex_format_op_dims(dims, op);
+        hex_format_op_strides(strides, op);
+        hex_format_op_types(types, op);
+        hex_format_op_buffs(buffs, op);
+        hex_format_op_names(names, op);
+
+        HEX_VERBOSE("ggml-hex: %s device-supports-op %s : %s : %s : %s : %s : %s : (%d)\n", sess->name.c_str(),
+                    ggml_op_name(op->op), names, dims, types, strides, buffs, (int) supp);
+    }
+
+    return supp;
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_hexagon_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    if (buft->iface.get_alignment != ggml_backend_hexagon_buffer_type_get_alignment) {
+        return false;
+    }
+
+    auto s0 = static_cast<ggml_hexagon_session *>(dev->context);
+    auto s1 = static_cast<ggml_backend_hexagon_buffer_type_context *>(buft->context)->sess;
+
+    // Need session/domain-id for buffers to be compatible
+    bool supp = (s0->session_id == s1->session_id);
+
+    HEX_VERBOSE("ggml-hex: %s device-supports-buft %s (%d)\n", s0->name.c_str(), s1->name.c_str(), (int) supp);
+
+    return supp;
+}
+
+static ggml_backend_buffer_type_t * ggml_backend_hexagon_device_get_extra_buffers_type(ggml_backend_dev_t dev) {
+    auto s0 = static_cast<ggml_hexagon_session *>(dev->context);
+    HEX_VERBOSE("ggml-hex: device-get-extra-buft : %s \n", s0->name.c_str());
+
+    static ggml_backend_buffer_type_t bufts[2];
+    bufts[0] = ggml_backend_hexagon_device_get_repack_buffer_type(dev);
+    bufts[1] = NULL;
+    return bufts;
+}
+
+static const struct ggml_backend_device_i ggml_backend_hexagon_device_i = {
+    /* .get_name             = */ ggml_backend_hexagon_device_get_name,
+    /* .get_description      = */ ggml_backend_hexagon_device_get_description,
+    /* .get_memory           = */ ggml_backend_hexagon_device_get_memory,
+    /* .get_type             = */ ggml_backend_hexagon_device_get_type,
+    /* .get_props            = */ ggml_backend_hexagon_device_get_props,
+    /* .init_backend         = */ ggml_backend_hexagon_device_init,
+    /* .get_buffer_type      = */ ggml_backend_hexagon_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,  // ggml_backend_hexagon_device_get_host_buffer_type,
+    /* .buffer_from_host_ptr = */ NULL,  // ggml_backend_hexagon_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_hexagon_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_hexagon_device_supports_buft,
+    /* .offload_op           = */ NULL,  // ggml_backend_hexagon_device_offload_op,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+//** backend registry
+
+#define GGML_HEXAGON_MAX_SESSIONS 16
+
+struct ggml_hexagon_registry {
+    ggml_hexagon_registry(ggml_backend_reg_t reg);
+    ~ggml_hexagon_registry();
+
+    ggml_backend_device devices[GGML_HEXAGON_MAX_SESSIONS];
+};
+
+ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
+    GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev);
+
+    if (!opt_arch) {
+        int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch);
+        if (err != 0) {
+            GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err);
+            opt_arch = 73;
+        }
+    }
+
+    GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
+
+    // Create devices / sessions
+    for (size_t i = 0; i < opt_ndev; i++) {
+        devices[i].iface   = ggml_backend_hexagon_device_i;
+        devices[i].reg     = reg;
+        try {
+            devices[i].context = new ggml_hexagon_session(i);
+        } catch (std::exception const &exc) {
+            GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i);
+            devices[i].context = nullptr;
+        }
+    }
+}
+
+ggml_hexagon_registry::~ggml_hexagon_registry() {
+    GGML_LOG_INFO("ggml-hex: releasing registry\n");
+
+    // Release devices / sessions
+    for (size_t i = 0; i < opt_ndev; i++) {
+        auto sess = static_cast<ggml_hexagon_session *>(devices[i].context);
+        delete sess;
+    }
+}
+
+static const char * ggml_backend_hexagon_reg_get_name(ggml_backend_reg_t reg) {
+    return "HTP";
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_hexagon_reg_get_device_count(ggml_backend_reg_t reg) {
+    return opt_ndev;
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_hexagon_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    auto hreg = static_cast<ggml_hexagon_registry *>(reg->context);
+
+    if (index >= opt_ndev || !hreg->devices[index].context) {
+        return nullptr;
+    }
+
+    return &hreg->devices[index];
+}
+
+static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
+        ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_hexagon_device_get_extra_buffers_type;
+        return (void *) fct;
+    }
+
+    return NULL;
+}
+
+static void ggml_hexagon_init(ggml_backend_reg * reg) {
+    // Basic sanity checks to make sure definitions match
+    static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0,
+                  "please update hexagon_type to match ggml_type");
+    static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0,
+                  "please update hexagon_type to match ggml_type");
+    static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
+                  "please update hexagon_type to match ggml_type");
+
+    const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
+    const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF");
+
+    opt_verbose      = str_verbose ? atoi(str_verbose) : 0;
+    opt_profile      = getenv("GGML_HEXAGON_PROFILE") != nullptr;
+    opt_etm          = getenv("GGML_HEXAGON_ETM") != nullptr;
+    opt_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL") != nullptr;
+
+    const char * str_opmask = getenv("GGML_HEXAGON_OPMASK");
+    if (str_opmask != nullptr) {
+        opt_opmask = strtoul(str_opmask, NULL, 0);
+    }
+    opt_opsync = getenv("GGML_HEXAGON_OPSYNC") != nullptr;
+
+    const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
+    if (str_ndev) {
+        opt_ndev = strtoul(str_ndev, NULL, 0);
+        if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
+            opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
+        }
+    }
+
+    const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
+    if (str_nhvx) {
+        opt_nhvx = strtoul(str_nhvx, NULL, 0);
+    }
+
+    const char * str_arch = getenv("GGML_HEXAGON_ARCH");
+    if (str_arch) {
+        if (str_arch[0] == 'v') {
+            str_arch++;
+        }
+        opt_arch = strtoul(str_arch, NULL, 0);
+    }
+
+    opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1;
+
+    reg->context = new ggml_hexagon_registry(reg);
+
+    HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req),
+                sizeof(struct htp_general_rsp));
+}
+
+static const struct ggml_backend_reg_i ggml_backend_hexagon_reg_i = {
+    /* .get_name         = */ ggml_backend_hexagon_reg_get_name,
+    /* .get_device_count = */ ggml_backend_hexagon_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_hexagon_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_hexagon_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_hexagon_reg(void) {
+    static bool initialized = false;
+
+    static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION,
+                                    /* .iface       = */ ggml_backend_hexagon_reg_i,
+                                    /* .context     = */ NULL };
+
+    {
+        static std::mutex           mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!initialized) {
+            ggml_hexagon_init(&reg);
+        }
+
+        initialized = true;
+    }
+
+    return &reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_hexagon_reg)
diff --git a/src/ggml-hexagon/htp-utils.c b/src/ggml-hexagon/htp-utils.c
new file mode 100644 (file)
index 0000000..e8a035a
--- /dev/null
@@ -0,0 +1,448 @@
+
+#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+#pragma clang diagnostic ignored "-Wsign-compare"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-backend-impl.h"
+#include "ggml-common.h"
+#include "ggml-hexagon.h"
+#include "ggml-impl.h"
+
+#include "htp-utils.h"
+
+#include <domain.h>
+#include <remote.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+domain * get_domain(int domain_id) {
+    int i    = 0;
+    int size = sizeof(supported_domains) / sizeof(domain);
+
+    for (i = 0; i < size; i++) {
+        if (supported_domains[i].id == domain_id) {
+            return &supported_domains[i];
+        }
+    }
+
+    return NULL;
+}
+
+bool is_valid_domain_id(int domain_id, int compute_only) {
+    int i    = 0;
+    int size = sizeof(supported_domains) / sizeof(domain);
+
+    if (compute_only) {
+        return is_CDSP(domain_id);
+    }
+
+    for (i = 0; i < size; i++) {
+        if (supported_domains[i].id == domain_id) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) {
+    int nErr    = AEE_SUCCESS;
+    int ss_info = 0;
+    if (domain_type != NULL) {
+        if (strcmp(domain_type, "LPASS") == 0) {
+            ss_info = FASTRPC_LPASS;
+        } else if (strcmp(domain_type, "HPASS") == 0) {
+            ss_info = FASTRPC_HPASS;
+        } else {
+            ss_info = FASTRPC_NSP;
+        }
+    }
+    system_req_payload req  = { 0 };
+    req.id                  = FASTRPC_GET_DOMAINS;
+    req.sys.domains         = NULL;
+    fastrpc_domain * domain = NULL;
+    if (ss_info != 0) {
+        req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info);
+    } else {
+        req.sys.flags = 0;
+    }
+#ifdef _WIN32
+    nErr = AEE_EUNSUPPORTED;
+    goto bail;
+#endif
+    if (remote_system_request) {
+        nErr = remote_system_request(&req);
+        if (nErr != AEE_SUCCESS) {
+            GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
+            goto bail;
+        }
+        // Allocate memory for domain-info array
+        req.sys.max_domains = req.sys.num_domains;
+        if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) {
+            nErr = AEE_ENOMEMORY;
+            GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains");
+            goto bail;
+        }
+
+        nErr = remote_system_request(&req);
+        if (nErr != AEE_SUCCESS) {
+            GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
+            goto bail;
+        }
+
+        for (int i = 0; i < req.sys.num_domains; i++) {
+            // Verify that only requested type domains were returned
+            domain = &req.sys.domains[i];
+            if (domain->type != ss_info && domain_type != NULL) {
+                nErr = -1;
+                GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n");
+                goto bail;
+            }
+        }
+        *domains_info = req.sys.domains;
+        *num_domains  = req.sys.num_domains;
+    } else {
+        nErr = AEE_EUNSUPPORTED;
+        goto bail;
+    }
+bail:
+    if (nErr && !req.sys.domains) {
+        free(req.sys.domains);
+    }
+    return nErr;
+}
+
+int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) {
+    int                              err  = 0;
+    remote_rpc_effective_domain_id_t sess = { 0 };
+
+    sess.domain_name     = domain_name;
+    sess.domain_name_len = strlen(domain_name);
+    sess.session_id      = session_id;
+
+    err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess));
+    if (err) {
+        GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name,
+               session_id);
+        return err;
+    }
+
+    *effec_domain_id = sess.effective_domain_id;
+    return err;
+}
+
+int get_dsp_support(int * domain) {
+    int nErr = AEE_SUCCESS;
+    *domain  = CDSP_DOMAIN_ID;  // DSP domain default value is CDSP_DOMAIN_ID
+
+    if (remote_handle_control) {
+        struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 };
+        nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
+        if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
+            GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
+            goto bail;
+        }
+
+        if (dsp_capability_domain.capability == 0) {
+            dsp_capability_domain.domain       = ADSP_DOMAIN_ID;  // Check for ADSP support.
+            dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT;
+            dsp_capability_domain.capability   = 0;
+            nErr                               = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain,
+                                                                       sizeof(struct remote_dsp_capability));
+            if (dsp_capability_domain.capability) {
+                *domain = ADSP_DOMAIN_ID;  // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID
+            }
+        }
+
+        if (nErr != AEE_SUCCESS) {
+            GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr);
+            goto bail;
+        }
+    } else {
+        nErr = AEE_EUNSUPPORTEDAPI;
+        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
+    }
+
+bail:
+    return nErr;
+}
+
+int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) {
+    int nErr    = AEE_SUCCESS;
+    *capability = 0;
+
+    if (attr == VTCM_PAGE || attr == VTCM_COUNT) {
+    } else {
+        nErr = AEE_EBADPARM;
+        GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n");
+        goto bail;
+    }
+    if (remote_handle_control) {
+        if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) {
+            /*
+            * Query the DSP for VTCM information
+            * Since the ADSP does not have a dedicated VTCM, we expect the output to be 0
+            */
+            struct remote_dsp_capability dsp_capability_vtcm_dsp;
+            dsp_capability_vtcm_dsp.domain       = (uint32_t) domain;
+            dsp_capability_vtcm_dsp.attribute_ID = attr;
+            dsp_capability_vtcm_dsp.capability   = (uint32_t) 0;
+            nErr                                 = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp,
+                                                                         sizeof(struct remote_dsp_capability));
+            if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
+                GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
+                GGML_LOG_ERROR("Running the usecase without checking the capability\n");
+                nErr = AEE_SUCCESS;
+                goto bail;
+            } else if (nErr == AEE_SUCCESS) {
+                *capability = dsp_capability_vtcm_dsp.capability;
+            } else {
+                GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr);
+                goto bail;
+            }
+        } else {
+            nErr = AEE_EUNSUPPORTED;
+            GGML_LOG_ERROR("Unsupported domain %d\n", domain);
+            goto bail;
+        }
+    } else {
+        nErr = AEE_EUNSUPPORTEDAPI;
+        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
+    }
+
+bail:
+    return nErr;
+}
+
+bool is_unsignedpd_supported(int domain_id) {
+    int nErr = AEE_SUCCESS;
+    if (remote_handle_control) {
+        struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 };
+        nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
+        if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
+            GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n");
+            return false;
+        }
+        if (nErr) {
+            GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr);
+            return false;
+        }
+        if (dsp_capability_domain.capability == 1) {
+            return true;
+        }
+    } else {
+        nErr = AEE_EUNSUPPORTEDAPI;
+        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n");
+        return false;
+    }
+    return false;
+}
+
+bool get_unsignedpd_support(void) {
+    return is_unsignedpd_supported(CDSP_DOMAIN_ID);
+}
+
+bool is_async_fastrpc_supported(int domain) {
+    int nErr = AEE_SUCCESS;
+    if (remote_handle_control) {
+        if (domain == CDSP_DOMAIN_ID) {
+            /*
+            * Query the DSP for ASYNC_FASTRPC_SUPPORT information
+            * Async fastrpc is supported only on CDSP
+            */
+            struct remote_dsp_capability dsp_capability_async_support;
+            dsp_capability_async_support.domain       = (uint32_t) domain;
+            dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT;
+            dsp_capability_async_support.capability   = (uint32_t) 0;
+            nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support,
+                                         sizeof(struct remote_dsp_capability));
+            if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
+                GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
+                GGML_LOG_ERROR("Running the usecase without checking the capability\n");
+                nErr = AEE_SUCCESS;
+                goto bail;
+            } else if (dsp_capability_async_support.capability == 1) {
+                return true;
+            }
+            if (nErr != AEE_SUCCESS) {
+                GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr);
+                goto bail;
+            }
+        } else {
+            nErr = AEE_EUNSUPPORTED;
+            GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain);
+            goto bail;
+        }
+    } else {
+        nErr = AEE_EUNSUPPORTEDAPI;
+        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
+    }
+
+bail:
+    return false;
+}
+
+bool is_status_notification_supported(int domain) {
+    int nErr = AEE_SUCCESS;
+
+    if (remote_handle_control) {
+        /*
+        * Query the DSP for STATUS_NOTIFICATION_SUPPORT information
+        * DSP User PD status notification Support
+        */
+        struct remote_dsp_capability dsp_capability_status_notification_support;
+        dsp_capability_status_notification_support.domain       = (uint32_t) domain;
+        dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT;
+        dsp_capability_status_notification_support.capability   = (uint32_t) 0;
+        nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support,
+                                     sizeof(struct remote_dsp_capability));
+        if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
+            GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
+            GGML_LOG_ERROR("Running the usecase without checking the capability\n");
+            nErr = AEE_SUCCESS;
+            goto bail;
+        } else if (dsp_capability_status_notification_support.capability == 1) {
+            return true;
+        }
+        if (nErr != AEE_SUCCESS) {
+            GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr);
+            goto bail;
+        }
+    } else {
+        nErr = AEE_EUNSUPPORTEDAPI;
+        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
+    }
+
+bail:
+    return false;
+}
+
+int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) {
+    int nErr    = AEE_SUCCESS;
+    *capability = 0;
+
+    if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) {
+        nErr = AEE_EBADPARM;
+        GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n");
+        goto bail;
+    }
+    if (remote_handle_control) {
+        if (domain == CDSP_DOMAIN_ID) {
+            /*
+            * Query the DSP for HMX SUPPORT information
+            * HMX is supported on CDSP only
+            */
+            struct remote_dsp_capability dsp_capability_hmx_dsp;
+            dsp_capability_hmx_dsp.domain       = (uint32_t) domain;
+            dsp_capability_hmx_dsp.attribute_ID = attr;
+            dsp_capability_hmx_dsp.capability   = (uint32_t) 0;
+            nErr                                = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp,
+                                                                        sizeof(struct remote_dsp_capability));
+            if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
+                GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
+                GGML_LOG_ERROR("Running the usecase without checking the capability\n");
+                nErr = AEE_SUCCESS;
+                goto bail;
+            } else if (nErr == AEE_SUCCESS) {
+                *capability = dsp_capability_hmx_dsp.capability;
+            } else {
+                GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr);
+                goto bail;
+            }
+        } else {
+            nErr = AEE_EUNSUPPORTED;
+            GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain);
+            goto bail;
+        }
+    } else {
+        nErr = AEE_EUNSUPPORTEDAPI;
+        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
+    }
+
+bail:
+    return nErr;
+}
+
+int get_hex_arch_ver(int domain, int * arch) {
+    if (!remote_handle_control) {
+        GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
+        return AEE_EUNSUPPORTEDAPI;
+    }
+
+    struct remote_dsp_capability arch_ver;
+    arch_ver.domain       = (uint32_t) domain;
+    arch_ver.attribute_ID = ARCH_VER;
+    arch_ver.capability   = (uint32_t) 0;
+
+    int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
+    if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
+        GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
+        return AEE_EUNSUPPORTEDAPI;
+    }
+
+    if (err != AEE_SUCCESS) {
+        GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
+        return err;
+    }
+
+    switch (arch_ver.capability & 0xff) {
+        case 0x73:
+            *arch = 73;
+            return 0;
+        case 0x75:
+            *arch = 75;
+            return 0;
+        case 0x79:
+            *arch = 79;
+            return 0;
+        case 0x81:
+            *arch = 81;
+            return 0;
+    }
+    return -1;
+}
+
+int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) {
+    int nErr    = AEE_SUCCESS;
+    *capability = 0;
+
+    if (remote_handle_control) {
+        if (domain == CDSP_DOMAIN_ID) {
+            /*
+            * Query the DSP for HVX SUPPORT information
+            * HVX is supported on CDSP only
+            */
+            struct remote_dsp_capability dsp_capability_hvx_dsp;
+            dsp_capability_hvx_dsp.domain       = (uint32_t) domain;
+            dsp_capability_hvx_dsp.attribute_ID = attr;
+            dsp_capability_hvx_dsp.capability   = (uint32_t) 0;
+            nErr                                = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp,
+                                                                        sizeof(struct remote_dsp_capability));
+            if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
+                GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
+                GGML_LOG_ERROR("Running the usecase without checking the capability\n");
+                nErr = AEE_SUCCESS;
+                goto bail;
+            } else if (nErr == AEE_SUCCESS) {
+                *capability = dsp_capability_hvx_dsp.capability;
+            } else {
+                GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr);
+                goto bail;
+            }
+        } else {
+            nErr = AEE_EUNSUPPORTED;
+            GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain);
+            goto bail;
+        }
+    } else {
+        nErr = AEE_EUNSUPPORTEDAPI;
+        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
+    }
+
+bail:
+    return nErr;
+}
diff --git a/src/ggml-hexagon/htp-utils.h b/src/ggml-hexagon/htp-utils.h
new file mode 100644 (file)
index 0000000..66f9fd3
--- /dev/null
@@ -0,0 +1,219 @@
+#ifndef HTP_UTILS_H
+#define HTP_UTILS_H
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include <AEEStdErr.h>
+#include <inttypes.h>
+#include <remote.h>
+#include <stdbool.h>
+
+/* Offset to differentiate HLOS and Hexagon error codes.
+   Stores the value of AEE_EOFFSET for Hexagon. */
+#ifndef DSP_OFFSET
+#    define DSP_OFFSET 0x80000400
+#endif
+
+/* Errno for connection reset by peer. */
+#ifndef ECONNRESET
+#    ifdef __hexagon__
+#        define ECONNRESET 104
+#    endif
+#endif
+
+/* Abstraction of different OS specific sleep APIs.
+   SLEEP accepts input in seconds. */
+#ifndef SLEEP
+#    ifdef __hexagon__
+#        define SLEEP(x)                      \
+            { /* Do nothing for simulator. */ \
+            }
+#    else
+#        ifdef _WINDOWS
+#            define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
+#        else
+#            define SLEEP(x) sleep(x)        /* sleep accepts input in seconds. */
+#        endif
+#    endif
+#endif
+
+/* Include windows specific header files. */
+#ifdef _WINDOWS
+#    include <sysinfoapi.h>
+#    include <windows.h>
+#    define _CRT_SECURE_NO_WARNINGS         1
+#    define _WINSOCK_DEPRECATED_NO_WARNINGS 1
+/* Including this file for custom implementation of getopt function. */
+#    include "getopt_custom.h"
+#endif
+
+/* Includes and defines for all HLOS except windows */
+#if !defined(__hexagon__) && !defined(_WINDOWS)
+#    include "unistd.h"
+
+#    include <sys/time.h>
+#endif
+
+/* Includes and defines for Hexagon and all HLOS except Windows. */
+#if !defined(_WINDOWS)
+/* Weak reference to remote symbol for compilation. */
+#    pragma weak remote_session_control
+#    pragma weak remote_handle_control
+#    pragma weak remote_handle64_control
+#    pragma weak fastrpc_mmap
+#    pragma weak fastrpc_munmap
+#endif
+
+#if !defined(_WINDOWS)
+#    pragma weak remote_system_request
+#endif
+/**
+ * Wrapper for FastRPC Capability API: query DSP support.
+ *
+ * @param[out]  domain pointer to supported domain.
+ * @return      0          if query is successful.
+ *              non-zero   if error, return value points to the error.
+ */
+int get_dsp_support(int * domain);
+
+/**
+ * Wrapper for FastRPC Capability API: query VTCM information.
+ *
+ * @param[in]   domain value of domain in the queried.
+ * @param[out]  capability capability value of the attribute queried.
+ * @param[in]   attr value of the attribute to the queried.
+ * @return      0          if query is successful.
+ *              non-zero   if error, return value points to the error.
+ */
+int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr);
+
+/**
+ * Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain.
+ *
+ * @return      true          if unsigned pd is supported.
+ *              false         if unsigned pd is not supported, capability query failed.
+ */
+
+bool get_unsignedpd_support(void);
+
+/**
+ * Wrapper for FastRPC Capability API: query unsigned pd support.
+ *
+ * @param[in]   domain value of domain in the queried.
+ * @return      true          if unsigned pd is supported.
+ *              false         if unsigned pd is not supported, capability query failed.
+ */
+
+bool is_unsignedpd_supported(int domain_id);
+
+/**
+ * is_valid_domain_id API: query a domain id is valid.
+ *
+ * @param[in]   domain value of domain in the queried.
+ * @param[in]   compute_only value of domain is only compared with CDSP domains supported by the target when enabled.
+ * @return      true          if value of domain is valid.
+ *              false         if value of domain is not valid.
+ */
+
+bool is_valid_domain_id(int domain_id, int compute_only);
+
+/**
+ * get_domain API: get domain struct from domain value.
+ *
+ * @param[in]  domain value of a domain
+ * @return     Returns domain struct of the domain if it is supported or else
+ *             returns NULL.
+ *
+ */
+
+domain * get_domain(int domain_id);
+
+/**
+ * get_domains_info API: get information for all the domains available on the device
+ *
+ * @param[in]  domain_type pointer to domain type
+ * @param[in]  num_domains pointer to number of domains
+ * @param[in]  domains_info pointer to save discovered domains information.
+ * @return     0 if query is successful.
+ *              non-zero if error, return value points to the error.
+ *
+ * It is user's responsibility to free the memory used to store the domains info whose address is present in domains_info before closing the application.
+ *
+ */
+
+int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info);
+
+/**
+ * get_effective_domain_id API: get effective domain id for given session id
+ *
+ * @param[in]  domain_name pointer to domain name
+ * @param[in]  session_id
+ * @param[in]  effec_domain_id pointer to save obtained effective domain id.
+ * @return     0 if query is successful.
+ *              non-zero if error, return value points to the error.
+ *
+ */
+
+int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id);
+
+/**
+ * is_async_fastrpc_supported API: query a domain id has async fastrpc supported or not
+ *
+ * @param[in]  domain_id value of a domain
+ * @return     Returns true or false stating support of Async FastRPC
+ *
+ */
+
+bool is_async_fastrpc_supported(int domain_id);
+
+/**
+ * is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information
+ *
+ * @param[in]  domain_id value of a domain
+ * @return     Returns true or false stating status notification support information
+ *
+ */
+bool is_status_notification_supported(int domain_id);
+
+/**
+ * get_hmx_support_info API: query the DSP for HMX SUPPORT information
+ *
+ * @param[in]   domain_id value of a domain
+ * @param[out]  capability capability value of the attribute queried.
+ * @param[in]   attr value of the attribute to the queried.
+ * @return      0 if query is successful.
+ *              non-zero if error, return value points to the error.
+ *
+ */
+int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr);
+
+/**
+ * get_hex_arch_ver API: query the Hexagon processor architecture version information
+ *
+ * @param[in]   domain_id value of a domain
+ * @param[out]  Arch version (73, 75, ...)
+ * @return      0 if query is successful.
+ *              non-zero if error, return value points to the error.
+ *
+ */
+int get_hex_arch_ver(int domain, int * arch);
+
+/**
+ * get_hvx_support_info API: query the DSP for HVX SUPPORT information
+ *
+ * @param[in]   domain_id value of a domain
+ * @param[out]  capability capability value of the attribute queried.
+ * @param[in]   attr value of the attribute to the queried.
+ * @return      0 if query is successful.
+ *              non-zero if error, return value points to the error.
+ *
+ */
+int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  //DSP_CAPABILITIES_UTILS_H
diff --git a/src/ggml-hexagon/htp/CMakeLists.txt b/src/ggml-hexagon/htp/CMakeLists.txt
new file mode 100644 (file)
index 0000000..22e3fea
--- /dev/null
@@ -0,0 +1,40 @@
+cmake_minimum_required(VERSION 3.22.2)
+project(ggml-htp C CXX ASM)
+
+include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
+
+include_directories(
+    ${HEXAGON_SDK_ROOT}/incs
+    ${HEXAGON_SDK_ROOT}/incs/stddef
+    ${CMAKE_CURRENT_SOURCE_DIR}/../..
+    ${CMAKE_CURRENT_SOURCE_DIR}/..
+    ${CMAKE_CURRENT_SOURCE_DIR}
+    ${CMAKE_CURRENT_BINARY_DIR})
+
+set(HTP_LIB ggml-htp-${DSP_VERSION})
+
+add_library(${HTP_LIB} SHARED
+    main.c
+    htp_iface_skel.c
+    worker-pool.c
+    htp-dma.c
+    hvx-sigmoid.c
+    hvx-inverse.c
+    hvx-exp.c
+    hvx-utils.c
+    matmul-ops.c
+    binary-ops.c
+    unary-ops.c
+    softmax-ops.c
+    act-ops.c
+    rope-ops.c
+)
+
+target_compile_definitions(${HTP_LIB} PRIVATE
+    $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>)
+
+build_idl(htp_iface.idl ${HTP_LIB})
+
+set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
+
+install(TARGETS ${HTP_LIB})
diff --git a/src/ggml-hexagon/htp/act-ops.c b/src/ggml-hexagon/htp/act-ops.c
new file mode 100644 (file)
index 0000000..1604497
--- /dev/null
@@ -0,0 +1,448 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+#    define FARF_HIGH 1
+#endif
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <HAP_ps.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <qurt_thread.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+#define htp_act_preamble3              \
+    const uint32_t ne00 = src0->ne[0]; \
+    const uint32_t ne01 = src0->ne[1]; \
+    const uint32_t ne02 = src0->ne[2]; \
+    const uint32_t ne03 = src0->ne[3]; \
+                                       \
+    const uint32_t ne10 = src1->ne[0]; \
+    const uint32_t ne11 = src1->ne[1]; \
+    const uint32_t ne12 = src1->ne[2]; \
+    const uint32_t ne13 = src1->ne[3]; \
+                                       \
+    const uint32_t ne0 = dst->ne[0];   \
+    const uint32_t ne1 = dst->ne[1];   \
+    const uint32_t ne2 = dst->ne[2];   \
+    const uint32_t ne3 = dst->ne[3];   \
+                                       \
+    const uint32_t nb00 = src0->nb[0]; \
+    const uint32_t nb01 = src0->nb[1]; \
+    const uint32_t nb02 = src0->nb[2]; \
+    const uint32_t nb03 = src0->nb[3]; \
+                                       \
+    const uint32_t nb10 = src1->nb[0]; \
+    const uint32_t nb11 = src1->nb[1]; \
+    const uint32_t nb12 = src1->nb[2]; \
+    const uint32_t nb13 = src1->nb[3]; \
+                                       \
+    const uint32_t nb0 = dst->nb[0];   \
+    const uint32_t nb1 = dst->nb[1];   \
+    const uint32_t nb2 = dst->nb[2];   \
+    const uint32_t nb3 = dst->nb[3];
+
+#define htp_act_preamble2              \
+    const uint32_t ne00 = src0->ne[0]; \
+    const uint32_t ne01 = src0->ne[1]; \
+    const uint32_t ne02 = src0->ne[2]; \
+    const uint32_t ne03 = src0->ne[3]; \
+                                       \
+    const uint32_t ne0 = dst->ne[0];   \
+    const uint32_t ne1 = dst->ne[1];   \
+    const uint32_t ne2 = dst->ne[2];   \
+    const uint32_t ne3 = dst->ne[3];   \
+                                       \
+    const uint32_t nb00 = src0->nb[0]; \
+    const uint32_t nb01 = src0->nb[1]; \
+    const uint32_t nb02 = src0->nb[2]; \
+    const uint32_t nb03 = src0->nb[3]; \
+                                       \
+    const uint32_t nb0 = dst->nb[0];   \
+    const uint32_t nb1 = dst->nb[1];   \
+    const uint32_t nb2 = dst->nb[2];   \
+    const uint32_t nb3 = dst->nb[3];
+
+static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
+                                       const struct htp_tensor * src1,
+                                       struct htp_tensor *       dst,
+                                       const int32_t *           op_params,
+                                       struct htp_spad *         src0_spad,
+                                       struct htp_spad *         src1_spad,
+                                       struct htp_spad *         dst_spad,
+                                       uint32_t                  nth,
+                                       uint32_t                  ith,
+                                       uint32_t                  src0_nrows_per_thread) {
+    htp_act_preamble3;
+
+    size_t src0_row_size = nb01;
+    size_t src1_row_size = nb11;
+    size_t dst_row_size  = nb1;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    int is_aligned = 1;
+    int opt_path   = 0;
+    if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
+        is_aligned = 0;
+        FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
+    }
+    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
+        opt_path = 1;
+    }
+
+    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
+
+    bool src1_valid = src1->ne[0];
+    if (!src1_valid) {
+        data_src1     = data_src0;
+        src1_row_size = src0_row_size;
+    }
+
+    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
+    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
+    uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_row_size);
+
+    const int32_t swapped = op_params[1];
+
+    const int nc = (src1_valid) ? ne0 : ne0 / 2;
+
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
+        const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
+        const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
+        float * restrict dst        = (float *) (data_dst + (ir * dst_row_size));
+
+        if (ir + 1 < src0_end_row) {
+            htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
+        }
+
+        if (!src1_valid) {
+            src0 += swapped ? nc : 0;
+            src1 += swapped ? 0 : nc;
+        }
+
+        if (1 == opt_path) {
+            hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
+            hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
+                                (uint8_t *) dst, nc);
+        } else {
+            hvx_exp_f32((const uint8_t *) src0, src0_spad_data, nc, true);
+            hvx_add_scalar_f32(src0_spad_data, 1.0, src1_spad_data, nc);
+            hvx_inverse_f32(src1_spad_data, src0_spad_data, nc);
+
+            hvx_mul_f32((const uint8_t *) src0, src0_spad_data, dst_spad_data, nc);
+            hvx_mul_f32(dst_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc);
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "swiglu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
+         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
+                                           const struct htp_tensor * src1,
+                                           struct htp_tensor *       dst,
+                                           const int32_t *           op_params,
+                                           struct htp_spad *         src0_spad,
+                                           struct htp_spad *         src1_spad,
+                                           struct htp_spad *         dst_spad,
+                                           uint32_t                  nth,
+                                           uint32_t                  ith,
+                                           uint32_t                  src0_nrows_per_thread) {
+    htp_act_preamble3;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const size_t src0_row_size = nb01;
+    const size_t src1_row_size = nb11;
+    const size_t dst_row_size  = nb1;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
+        FARF(HIGH, "act-f32: unaligned addresses in activations op, possibly slower execution\n");
+    }
+
+    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
+
+    bool src1_valid = src1->ne[0];
+    if (!src1_valid) {
+        data_src1 = data_src0;
+    }
+
+    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
+    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
+    uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_row_size);
+
+    const int32_t swapped = op_params[1];
+    const float   alpha   = ((const float *) (op_params))[2];
+    const float   limit   = ((const float *) (op_params))[3];
+
+    const int nc = (src1_valid) ? ne0 : ne0 / 2;
+
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
+        const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
+        const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
+        float * restrict dst        = (float *) (data_dst + (ir * dst_row_size));
+
+        if (ir + 1 < src0_end_row) {
+            htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
+        }
+
+        if (!src1) {
+            src0 += swapped ? nc : 0;
+            src1 += swapped ? 0 : nc;
+        }
+
+        // x (src0_spad_data) = std::min(src0_p[k], limit);
+        hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
+        // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
+        hvx_clamp_scalar_f32((const uint8_t *) src1, limit, limit, src1_spad_data, nc);
+        // y (src1_spad_data)  = y1 + 1.f
+        hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc);
+        // x1 (dst_spad_data) = alpha * (x)
+        hvx_mul_scalar_f32(src0_spad_data, alpha, dst_spad_data, nc);
+        // x2 (dst_spad_data) = expf(-x1)
+        hvx_exp_f32(dst_spad_data, dst_spad_data, nc, true);
+        // x3 (dst_spad_data) = x2 + 1.f
+        hvx_add_scalar_f32(dst_spad_data, 1.0, dst_spad_data, nc);
+        // x4 (dst_spad_data) = 1 / x3
+        hvx_inverse_f32(dst_spad_data, dst_spad_data, nc);
+        // out_glu(dst_spad_data) = x * x4
+        hvx_mul_f32(src0_spad_data, dst_spad_data, dst_spad_data, nc);
+        // out = out_glu * (y + 1.f);
+        hvx_mul_f32(dst_spad_data, src1_spad_data, (uint8_t *) dst, nc);
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0],
+         src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2],
+         src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
+                                       struct htp_tensor *       dst,
+                                       const int32_t *           op_params,
+                                       struct htp_spad *         src0_spad,
+                                       struct htp_spad *         dst_spad,
+                                       uint32_t                  nth,
+                                       uint32_t                  ith,
+                                       uint32_t                  src0_nrows_per_thread) {
+    htp_act_preamble2;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const size_t src0_row_size = nb01;
+    const size_t dst_row_size  = nb1;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    int is_aligned = 1;
+    int opt_path   = 0;
+    if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
+        is_aligned = 0;
+        FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
+    }
+    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
+        opt_path = 1;
+    }
+
+    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
+
+    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
+    uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_row_size);
+
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
+        const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
+        float * restrict dst        = (float *) (data_dst + (ir * dst_row_size));
+
+        if (ir + 1 < src0_end_row) {
+            htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
+        }
+
+        if (1 == opt_path) {
+            hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, ne0);
+            hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
+        } else {
+            hvx_exp_f32((const uint8_t *) src0, src0_spad_data, ne0, true);
+            hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0);
+            hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0);
+
+            hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "silu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02,
+         ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+    unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
+                               octx->src0_nrows_per_thread);
+}
+
+static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+    glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
+                               &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
+}
+
+static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+    glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
+                                   &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
+}
+
+static int execute_op_activations_fp32(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) {
+        FARF(ERROR, "Non-contiguous tensors are not supported at this time \n");
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    worker_callback_t act_op_func;
+    const char *      op_type = NULL;
+
+    switch (octx->op) {
+        case HTP_OP_UNARY_SILU:
+            act_op_func = unary_silu_fp32;
+            op_type     = "silu-f32";
+            break;
+
+        case HTP_OP_GLU_SWIGLU:
+            act_op_func = glu_swiglu_fp32;
+            op_type     = "swiglu-f32";
+            break;
+
+        case HTP_OP_GLU_SWIGLU_OAI:
+            act_op_func = glu_swiglu_oai_fp32;
+            op_type     = "swiglu-oai-f32";
+            break;
+
+        default:
+            FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
+            return HTP_STATUS_NO_SUPPORT;
+    }
+
+    const uint32_t n_threads  = octx->n_threads;
+    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+
+    const size_t src0_row_size = src0->nb[1];
+    const size_t src1_row_size = src1->ne[0] ? src1->nb[1] : src0->nb[1];
+    const size_t dst_row_size  = dst->nb[1];
+
+    // VTCM scratchpads for all tensors
+    // N rows per thread, padded to HVX vector size
+    octx->dst_spad.size  = htp_round_up(dst_row_size, 128) * octx->n_threads;
+    octx->src0_spad.size = htp_round_up(src0_row_size, 128) * octx->n_threads;
+    octx->src1_spad.size = htp_round_up(src1_row_size, 128) * octx->n_threads;
+
+    size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+
+    if (src1->ne[0]) {
+        FARF(HIGH,
+             "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
+             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
+             src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
+             octx->dst_spad.size);
+    } else {
+        FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
+             src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+             octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+    }
+
+    // Make sure the reserved vtcm size is sufficient
+    if (octx->ctx->vtcm_size < spad_size) {
+        FARF(ERROR, "act-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
+             spad_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        uint32_t n_jobs = MIN(n_threads, src0_nrows);
+
+        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
+    }
+
+    return err;
+}
+
+int op_activations(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    switch (octx->src0.type) {
+        case HTP_TYPE_F32:
+            err = execute_op_activations_fp32(octx);
+            break;
+
+        default:
+            err = HTP_STATUS_NO_SUPPORT;
+            break;
+    }
+
+    return err;
+}
diff --git a/src/ggml-hexagon/htp/binary-ops.c b/src/ggml-hexagon/htp/binary-ops.c
new file mode 100644 (file)
index 0000000..92c0109
--- /dev/null
@@ -0,0 +1,344 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+#    define FARF_HIGH 1
+#endif
+
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <HAP_ps.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <qurt_thread.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0,
+                                      const uint8_t * src1,
+                                      uint8_t *       data_dst,
+                                      const int       num_elems);
+
+static hvx_elemwise_f32_func func_table_HVX[]     = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 };
+static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
+
+#define htp_binary_preamble            \
+    const uint32_t ne00 = src0->ne[0]; \
+    const uint32_t ne01 = src0->ne[1]; \
+    const uint32_t ne02 = src0->ne[2]; \
+    const uint32_t ne03 = src0->ne[3]; \
+                                       \
+    const uint32_t ne10 = src1->ne[0]; \
+    const uint32_t ne11 = src1->ne[1]; \
+    const uint32_t ne12 = src1->ne[2]; \
+    const uint32_t ne13 = src1->ne[3]; \
+                                       \
+    const uint32_t ne0 = dst->ne[0];   \
+    const uint32_t ne1 = dst->ne[1];   \
+    const uint32_t ne2 = dst->ne[2];   \
+    const uint32_t ne3 = dst->ne[3];   \
+                                       \
+    const uint32_t nb00 = src0->nb[0]; \
+    const uint32_t nb01 = src0->nb[1]; \
+    const uint32_t nb02 = src0->nb[2]; \
+    const uint32_t nb03 = src0->nb[3]; \
+                                       \
+    const uint32_t nb10 = src1->nb[0]; \
+    const uint32_t nb11 = src1->nb[1]; \
+    const uint32_t nb12 = src1->nb[2]; \
+    const uint32_t nb13 = src1->nb[3]; \
+                                       \
+    const uint32_t nb0 = dst->nb[0];   \
+    const uint32_t nb1 = dst->nb[1];   \
+    const uint32_t nb2 = dst->nb[2];   \
+    const uint32_t nb3 = dst->nb[3];
+
+static void binary_job_f32_per_thread(const struct htp_tensor * src0,
+                                      const struct htp_tensor * src1,
+                                      struct htp_tensor *       dst,
+                                      uint8_t *                 spad_data,
+                                      uint32_t                  nth,
+                                      uint32_t                  ith,
+                                      uint32_t                  src0_nrows_per_thread,
+                                      enum htp_op               op) {
+    htp_binary_preamble;
+
+    const size_t src0_row_size = nb01;
+    const size_t src1_row_size = nb11;
+    const size_t dst_row_size  = nb1;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    const uint32_t src1_nrows = ne11 * ne12 * ne13;  // src1 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    int is_aligned = 1;
+    int opt_path   = 0;
+    if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
+        (0 == htp_is_aligned((void *) dst->data, VLEN))) {
+        FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n");
+        is_aligned = 0;
+    }
+    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
+        opt_path = 1;
+    }
+
+    hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op];
+
+    uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
+
+    const uint32_t nr0 = ne00 / ne10;
+
+    const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
+    uint8_t * restrict dst_ptr        = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
+
+    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+    const uint8_t * restrict src1_ptr  = NULL;
+
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
+        src1_ptr = data_src1 + (ir % src1_nrows) * src1_row_size;
+
+        if (ir + 1 < src0_end_row) {
+            htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
+            if (src1_row_size == src0_row_size) {
+                htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size);
+            }
+        }
+
+        if (nr0 > 1) {
+            if ((1 == is_aligned) && (nr0 == ne00)) {
+                hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
+            } else {
+                for (uint32_t r = 0; r < nr0; r++) {
+                    memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11);
+                }
+            }
+            func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00);
+        } else {
+            func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
+        }
+
+        src0_ptr += src0_row_size;
+        dst_ptr += dst_row_size;
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
+         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
+                                             const struct htp_tensor * src1,
+                                             const struct htp_tensor * src2,
+                                             struct htp_tensor *       dst,
+                                             uint8_t *                 spad_data,
+                                             uint32_t                  nth,
+                                             uint32_t                  ith,
+                                             uint32_t                  src0_nrows_per_thread,
+                                             hvx_elemwise_f32_func     func_HVX) {
+    htp_binary_preamble;
+
+    const size_t src0_row_size = nb01;
+    const size_t src1_row_size = nb11;
+    const size_t dst_row_size  = nb1;
+
+    const uint32_t ne02_ne01  = ne02 * ne01;
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
+        (0 == htp_is_aligned((void *) dst->data, VLEN))) {
+        FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n");
+    }
+
+    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
+
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
+        // src0 indices
+        const uint32_t i03 = ir / ne02_ne01;
+        const uint32_t i02 = (ir - i03 * ne02_ne01) / ne01;
+        const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]);
+        assert(i11 >= 0 && i11 < ne11);
+
+        float * restrict dst_ptr        = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1);
+        const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01);
+        const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11);
+
+        if (ir + 1 < src0_end_row) {
+            htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
+            if (src1_row_size == src0_row_size) {
+                htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size);
+            }
+        }
+
+        const uint32_t nr0 = ne00 / ne10;
+        if (nr0 > 1) {
+            for (uint32_t r = 0; r < nr0; r++) {
+                memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10);
+            }
+            func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00);
+        } else {
+            func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth,
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
+         src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1],
+         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+
+    switch (octx->op) {
+        case HTP_OP_MUL:
+        case HTP_OP_ADD:
+        case HTP_OP_SUB:
+            binary_job_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->src1_spad.data, n, i,
+                                      octx->src0_nrows_per_thread, octx->op);
+            break;
+
+        case HTP_OP_ADD_ID:
+            binary_add_id_job_f32_per_thread(&octx->src0, &octx->src1, &octx->src2, &octx->dst, octx->src0_spad.data, n,
+                                             i, octx->src0_nrows_per_thread, hvx_add_f32);
+            break;
+
+        default:
+            FARF(ERROR, "Unknown Binary Op %u", octx->op);
+            break;
+    }
+}
+
+static int execute_op_binary_f32(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    worker_callback_t binary_op_func;
+    const char *      op_type = NULL;
+
+    switch (octx->op) {
+        case HTP_OP_MUL:
+            binary_op_func = binary_job_dispatcher_f32;
+            op_type        = "mul-f32";
+            break;
+
+        case HTP_OP_ADD:
+            binary_op_func = binary_job_dispatcher_f32;
+            op_type        = "add-f32";
+            break;
+
+        case HTP_OP_SUB:
+            binary_op_func = binary_job_dispatcher_f32;
+            op_type        = "sub-f32";
+            break;
+
+        case HTP_OP_ADD_ID:
+            binary_op_func = binary_job_dispatcher_f32;
+            op_type        = "add-id-f32";
+            break;
+
+        default:
+            FARF(ERROR, "Unsupported binary-Op %u\n", octx->op);
+            return HTP_STATUS_NO_SUPPORT;
+    }
+
+    const int      n_threads  = octx->n_threads;
+    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+
+    const size_t src0_row_size = src0->nb[1];
+    const size_t src1_row_size = src1->nb[1];
+    const size_t dst_row_size  = dst->nb[1];
+
+    // VTCM scratchpads for all tensors
+    octx->dst_spad.size  = htp_round_up(dst_row_size, 128) * n_threads;
+    octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
+    octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
+
+    size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+
+    FARF(HIGH,
+         "%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
+         op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
+         src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
+         octx->dst_spad.size);
+
+    // Make sure the reserved vtcm size is sufficient
+    if (octx->ctx->vtcm_size < spad_size) {
+        FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
+             octx->ctx->vtcm_size, spad_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        uint32_t n_jobs = MIN(n_threads, src0_nrows);
+
+        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+
+        worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
+    }
+
+    return err;
+}
+
+int op_binary(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    switch (octx->src0.type) {
+        case HTP_TYPE_F32:
+            err = execute_op_binary_f32(octx);
+            break;
+
+        default:
+            err = HTP_STATUS_NO_SUPPORT;
+            break;
+    }
+
+    return err;
+}
diff --git a/src/ggml-hexagon/htp/cmake-toolchain.cmake b/src/ggml-hexagon/htp/cmake-toolchain.cmake
new file mode 100644 (file)
index 0000000..7fa236e
--- /dev/null
@@ -0,0 +1,157 @@
+if (HEXAGON_TOOLCHAIN_INCLUDED)
+  return()
+endif()
+set(HEXAGON_TOOLCHAIN_INCLUDED true)
+
+#Cross Compiling for Hexagon
+set(HEXAGON TRUE)
+set(CMAKE_SYSTEM_NAME QURT)
+set(CMAKE_SYSTEM_PROCESSOR Hexagon)
+set(CMAKE_SYSTEM_VERSION "1") #${HEXAGON_PLATFORM_LEVEL})
+set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
+set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
+set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
+set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
+set(CUSTOM_RUNELF_PATH "")
+
+#To fix backward compatibility with EAI addon.
+if (NOT HEXAGON_SDK_ROOT)
+    set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
+endif()
+
+if (NOT HEXAGON_TOOLS_ROOT)
+    if (DEFINED ENV{HEXAGON_TOOLS_ROOT})
+        set(HEXAGON_TOOLS_ROOT $ENV{HEXAGON_TOOLS_ROOT})
+    endif()
+    if(NOT HEXAGON_TOOLS_ROOT)
+        set(HEXAGON_TOOLS_ROOT $ENV{DEFAULT_HEXAGON_TOOLS_ROOT})
+    endif()
+endif()
+
+file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
+file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}"   HEXAGON_SDK_ROOT)
+
+#Get the Binary extension of the Hexagon Toolchain
+if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
+    set(HEXAGON_TOOLCHAIN_SUFFIX .exe)
+endif()
+message(DEBUG "CMAKE_HOST_SYSTEM_NAME:${CMAKE_HOST_SYSTEM_NAME}")
+
+include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_arch.cmake)
+
+set(HEXAGON_TOOLCHAIN ${HEXAGON_TOOLS_ROOT})
+set(HEXAGON_LIB_DIR "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib")
+set(HEXAGON_ISS_DIR ${HEXAGON_TOOLCHAIN}/Tools/lib/iss)
+
+set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
+    HEXAGON_SDK_ROOT
+    HEXAGON_TOOLS_ROOT
+)
+
+#QURT Related includes and linker flags
+set(V_ARCH ${HEXAGON_ARCH})
+set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}")
+set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}")
+
+if( ${TREE} MATCHES PAKMAN )
+    set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}")
+endif()
+message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}")
+set(RTOS_DIR ${_QURT_INSTALL_DIR})
+set(QCC_DIR "${HEXAGON_QCC_DIR}/${V_ARCH}/G0")
+set(TARGET_DIR "${HEXAGON_LIB_DIR}/${V_ARCH}/G0")
+
+include_directories(
+    ${_QURT_INSTALL_DIR}/include
+    ${_QURT_INSTALL_DIR}/include/qurt
+    ${_QURT_INSTALL_DIR}/include/posix
+    )
+
+set(QURT_START_LINK_LIBS)
+set(QURT_START_LINK_LIBS
+    "${TARGET_DIR}/init.o"
+    "${RTOS_DIR}/lib/crt1.o"
+    "${RTOS_DIR}/lib/debugmon.o"
+    "${RTOS_DIR}/lib/libqurt.a"
+    "${TARGET_DIR}/libc.a"
+    "${TARGET_DIR}/libqcc.a"
+    "${TARGET_DIR}/libhexagon.a"
+    "${RTOS_DIR}/lib/libqurtcfs.a"
+    "${RTOS_DIR}/lib/libtimer_island.a"
+    "${RTOS_DIR}/lib/libtimer_main.a"
+    "${RTOS_DIR}/lib/libposix.a"
+    )
+STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}")
+
+set(QURT_END_LINK_LIBS
+    ${TARGET_DIR}/fini.o
+    )
+
+#Non QURT related includes and linker flags
+
+set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}")
+
+if (NOT NO_WRAP_MEM_API)
+    set(WRAP_MALLOC   -Wl,--wrap=malloc)
+    set(WRAP_CALLOC   -Wl,--wrap=calloc)
+    set(WRAP_FREE     -Wl,--wrap=free)
+    set(WRAP_REALLOC  -Wl,--wrap=realloc)
+    set(WRAP_MEMALIGN -Wl,--wrap=memalign)
+endif()
+
+set(PIC_SHARED_LD_FLAGS
+    -mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}
+    -G0
+    -fpic
+    -Wl,-Bsymbolic
+    -Wl,-L${TARGET_DIR_NOOS}/G0/pic
+    -Wl,-L${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/
+    -Wl,--no-threads ${WRAP_MALLOC} ${WRAP_CALLOC} ${WRAP_FREE} ${WRAP_REALLOC} ${WRAP_MEMALIGN}
+    -shared
+    "-o <TARGET> <SONAME_FLAG><TARGET_SONAME>"
+    "<LINK_FLAGS>"
+    -Wl,--start-group
+    "<OBJECTS>"
+    "<LINK_LIBRARIES>"
+    -Wl,--end-group
+    -lc
+    )
+STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}")
+
+set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}")
+
+#System include paths
+include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs)
+include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)
+include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)
+
+#LLVM toolchain setup
+#Compiler paths, options and architecture
+set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})
+set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
+set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})
+set(CMAKE_ASM_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
+set(HEXAGON_LINKER ${CMAKE_C_COMPILER})
+set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)
+
+set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG   "-Wl,-soname,")
+set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
+
+#Compiler Options
+set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
+
+set(CMAKE_CXX_FLAGS_DEBUG          "${COMMON_FLAGS} -O0 -D_DEBUG -g")
+set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
+set(CMAKE_CXX_FLAGS_RELEASE        "${COMMON_FLAGS} -O3")
+
+set(CMAKE_C_FLAGS_DEBUG            "${COMMON_FLAGS} -O0 -D_DEBUG -g")
+set(CMAKE_C_FLAGS_RELWITHDEBINFO   "${COMMON_FLAGS} -O3 -g")
+set(CMAKE_C_FLAGS_RELEASE          "${COMMON_FLAGS} -O3")
+
+set(CMAKE_ASM_FLAGS_DEBUG          "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}")
+set(CMAKE_ASM_FLAGS_RELEASE        "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}")
+set(CMAKE_ASM_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}" )
+
+#Linker Options
+set(CMAKE_C_CREATE_SHARED_LIBRARY   "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}")
+set(CMAKE_CXX_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}")
diff --git a/src/ggml-hexagon/htp/htp-ctx.h b/src/ggml-hexagon/htp/htp-ctx.h
new file mode 100644 (file)
index 0000000..5c3d217
--- /dev/null
@@ -0,0 +1,40 @@
+#ifndef HTP_CTX_H
+#define HTP_CTX_H
+
+#include "htp-dma.h"
+#include "worker-pool.h"
+
+#include <assert.h>
+#include <dspqueue.h>
+#include <stdatomic.h>
+#include <stdint.h>
+
+#define HTP_MAX_NTHREADS 10
+
+// FIXME: move these into matmul-ops
+#define HTP_SPAD_SRC0_NROWS 16
+#define HTP_SPAD_SRC1_NROWS 16
+#define HTP_SPAD_DST_NROWS  2
+
+// Main context for htp DSP backend
+struct htp_context {
+    dspqueue_t            queue;
+    dma_queue *           dma[HTP_MAX_NTHREADS];
+    worker_pool_context_t worker_pool;
+    uint32_t              n_threads;
+
+    int thread_id;
+    int thread_prio;
+
+    uint8_t * vtcm_base;
+    size_t    vtcm_size;
+    uint32_t  vtcm_rctx;
+
+    atomic_bool vtcm_valid;
+    atomic_bool vtcm_inuse;
+    atomic_bool vtcm_needs_release;
+
+    uint32_t opmask;
+};
+
+#endif /* HTP_CTX_H */
diff --git a/src/ggml-hexagon/htp/htp-dma.c b/src/ggml-hexagon/htp/htp-dma.c
new file mode 100644 (file)
index 0000000..10c54b4
--- /dev/null
@@ -0,0 +1,69 @@
+#include "htp-dma.h"
+
+#include <stdbool.h>
+#include <stdlib.h>
+#include <string.h>
+
+#pragma clang diagnostic ignored "-Wunused-function"
+
+static inline uint32_t pow2_ceil(uint32_t x) {
+    if (x <= 1) {
+        return 1;
+    }
+    int p = 2;
+    x--;
+    while (x >>= 1) {
+        p <<= 1;
+    }
+    return p;
+}
+
+dma_queue * dma_queue_create(size_t capacity) {
+    dma_queue * q = (dma_queue *) memalign(32, sizeof(dma_queue));
+    if (q == NULL) {
+        FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__);
+        return NULL;
+    }
+
+    capacity = pow2_ceil(capacity);
+
+    memset(q, 0, sizeof(dma_queue));
+    q->capacity = capacity;
+    q->idx_mask = capacity - 1;
+
+    q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t));
+    memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t));
+
+    q->dst = (void **) memalign(4, capacity * sizeof(void *));
+    memset(q->dst, 0, capacity * sizeof(void *));
+
+    q->tail = &q->desc[capacity - 1];
+
+    if (!q->desc && !q->dst) {
+        FARF(ERROR, "%s: failed to allocate DMA queue items\n", __FUNCTION__);
+        return NULL;
+    }
+
+    FARF(HIGH, "dma-queue: capacity %u\n", capacity);
+
+    return q;
+}
+
+void dma_queue_delete(dma_queue * q) {
+    if (!q) {
+        return;
+    }
+    free(q->desc);
+    free(q->dst);
+    free(q);
+}
+
+void dma_queue_flush(dma_queue * q) {
+    while (1) {
+        uint32_t s = dmwait() & 0x3;
+        if (s == HEXAGON_UDMA_DM0_STATUS_IDLE) {
+            break;
+        }
+    }
+    q->tail = NULL;
+}
diff --git a/src/ggml-hexagon/htp/htp-dma.h b/src/ggml-hexagon/htp/htp-dma.h
new file mode 100644 (file)
index 0000000..4d0d54c
--- /dev/null
@@ -0,0 +1,119 @@
+#ifndef HTP_DMA_H
+#define HTP_DMA_H
+
+#include <HAP_farf.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <stdbool.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct {
+    hexagon_udma_descriptor_type1_t * desc;  // descriptor pointers
+    hexagon_udma_descriptor_type1_t * tail;  // tail pointer
+    void **                           dst;   // dst pointers
+    uint32_t                          push_idx;
+    uint32_t                          pop_idx;
+    uint32_t                          capacity;
+    uint32_t                          idx_mask;
+} dma_queue;
+
+dma_queue * dma_queue_create(size_t capacity);
+void        dma_queue_delete(dma_queue * q);
+void        dma_queue_flush(dma_queue * q);
+
+// TODO: technically we don't need these and could use Q6_dmstart/wait/etc instead
+// but those do not seem to always compiler properly.
+static inline void dmstart(void * next) {
+    asm volatile(" release(%0):at" : : "r"(next));
+    asm volatile(" dmstart(%0)" : : "r"(next));
+}
+
+static inline void dmlink(void * cur, void * next) {
+    asm volatile(" release(%0):at" : : "r"(next));
+    asm volatile(" dmlink(%0, %1)" : : "r"(cur), "r"(next));
+}
+
+static inline unsigned int dmpoll(void) {
+    unsigned int ret = 0;
+    asm volatile(" %0 = dmpoll" : "=r"(ret) : : "memory");
+    return ret;
+}
+
+static inline unsigned int dmwait(void) {
+    unsigned int ret = 0;
+    asm volatile(" %0 = dmwait" : "=r"(ret) : : "memory");
+    return ret;
+}
+
+static inline bool dma_queue_push(dma_queue *  q,
+                                  void *       dst,
+                                  const void * src,
+                                  size_t       dst_row_size,
+                                  size_t       src_row_size,
+                                  size_t       nrows) {
+    if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
+        return false;
+    }
+
+    hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx];
+
+    desc->next           = NULL;
+    desc->length         = 0;
+    desc->desctype       = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1;
+    desc->dstbypass      = 1;
+    desc->srcbypass      = 1;
+    desc->order          = 0;
+    desc->dstate         = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE;
+    desc->src            = (void *) src;
+    desc->dst            = (void *) dst;
+    desc->allocation     = 0;
+    desc->padding        = 0;
+    desc->roiwidth       = src_row_size;
+    desc->roiheight      = nrows;
+    desc->srcstride      = src_row_size;
+    desc->dststride      = dst_row_size;
+    desc->srcwidthoffset = 0;
+    desc->dstwidthoffset = 0;
+
+    q->dst[q->push_idx] = dst;
+
+    dmlink(q->tail, desc);
+    q->tail = desc;
+
+    // FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
+    q->push_idx = (q->push_idx + 1) & q->idx_mask;
+    return true;
+}
+
+static inline uint8_t * dma_queue_pop(dma_queue * q) {
+    if (q->push_idx == q->pop_idx) {
+        return NULL;
+    }
+
+    hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx];
+
+    // Wait for desc to complete
+    while (1) {
+        dmpoll();
+        if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) {
+            break;
+        }
+        // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
+    }
+
+    uint8_t * dst = (uint8_t *) q->dst[q->pop_idx];
+
+    // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
+    q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
+    return dst;
+}
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif /* HTP_DMA_H */
diff --git a/src/ggml-hexagon/htp/htp-msg.h b/src/ggml-hexagon/htp/htp-msg.h
new file mode 100644 (file)
index 0000000..f23d578
--- /dev/null
@@ -0,0 +1,156 @@
+#ifndef HTP_MSG_H
+#define HTP_MSG_H
+
+#include <assert.h>
+
+// ggml-common.h must be included prio to this header
+
+// Mask to enable various stages of the Ops.
+// Used for debugging and profiling.
+enum {
+    HTP_OPMASK_QUEUE    = (1 << 0),  // Enable Queueing (ie calls into the DSP)
+    HTP_OPMASK_QUANTIZE = (1 << 1),  // Enable Quantize
+    HTP_OPMASK_COMPUTE  = (1 << 2),  // Enable Compute
+};
+
+// Op flags
+enum {
+    HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0),  // Skip dynamic quantization (reuse quantized tensors)
+    HTP_OPFLAGS_SKIP_COMPUTE  = (1 << 1),  // Skip actual computation (used for profiling)
+    HTP_OPFLAGS_EARLY_WAKEUP  = (1 << 2)   // Send early wakeup notification
+};
+
+enum htp_status {
+    HTP_STATUS_OK             = 1,
+    HTP_STATUS_INTERNAL_ERR   = 2,
+    HTP_STATUS_NO_SUPPORT     = 3,
+    HTP_STATUS_INVAL_PARAMS   = 4,
+    HTP_STATUS_VTCM_TOO_SMALL = 5,
+};
+
+// The values must match the ggml_type.
+// Duplicated here because we can't include full ggml.h in the htp build.
+// We have some static_asserts in the cpp code to ensure things are in sync.
+enum htp_data_type {
+    HTP_TYPE_F32   = 0,
+    HTP_TYPE_F16   = 1,
+    HTP_TYPE_Q4_0  = 2,
+    HTP_TYPE_Q8_0  = 8,
+    HTP_TYPE_MXFP4 = 39,
+    HTP_TYPE_COUNT
+};
+
+// These values are manually translated over to HTP
+// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
+enum htp_op {
+    HTP_OP_MUL            = 0,
+    HTP_OP_ADD            = 1,
+    HTP_OP_SUB            = 2,
+    HTP_OP_DIV            = 3,
+    HTP_OP_MUL_MAT        = 4,
+    HTP_OP_MUL_MAT_ID     = 5,
+    HTP_OP_RMS_NORM       = 6,
+    HTP_OP_UNARY_SILU     = 7,
+    HTP_OP_GLU_SWIGLU     = 8,
+    HTP_OP_GLU_SWIGLU_OAI = 9,
+    HTP_OP_SOFTMAX        = 10,
+    HTP_OP_ADD_ID         = 11,
+    HTP_OP_ROPE           = 12,
+    INVALID
+};
+
+static inline size_t htp_type_block_size(uint32_t t) {
+    switch (t) {
+        case HTP_TYPE_F32:
+            return 1;
+        case HTP_TYPE_F16:
+            return 1;
+        case HTP_TYPE_Q4_0:
+            return QK4_0;
+        case HTP_TYPE_Q8_0:
+            return QK8_0;
+        case HTP_TYPE_MXFP4:
+            return QK_MXFP4;
+        default:
+            assert(0 && "unsupported HTP data type");
+    }
+    return 0;
+}
+
+static inline size_t htp_type_nbytes(uint32_t t) {
+    switch (t) {
+        case HTP_TYPE_F32:
+            return 4;
+        case HTP_TYPE_F16:
+            return 2;
+        case HTP_TYPE_Q4_0:
+            return sizeof(block_q4_0);
+        case HTP_TYPE_Q8_0:
+            return sizeof(block_q8_0);
+        case HTP_TYPE_MXFP4:
+            return sizeof(block_mxfp4);
+        default:
+            assert(0 && "unsupported HTP data type");
+    }
+    return 0;
+}
+
+static const char * htp_type_name(uint32_t t) {
+    switch (t) {
+        case HTP_TYPE_F32:
+            return "fp32";
+        case HTP_TYPE_F16:
+            return "fp16";
+        case HTP_TYPE_Q4_0:
+            return "q4_0";
+        case HTP_TYPE_Q8_0:
+            return "q8_0";
+        case HTP_TYPE_MXFP4:
+            return "mxfp4";
+    }
+    return 0;
+}
+
+// Internal types
+#define QK_Q4_0x4x2  256  // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
+#define QK_Q8_0x4x2  256  // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
+#define QK_MXFP4x4x2 256  // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
+
+#define HTP_MAX_DIMS 4
+
+struct htp_tensor {
+    uint32_t data;              // Buffer offset in the messages, and data pointer on the NSP
+    uint32_t type;              // Data type
+    uint32_t ne[HTP_MAX_DIMS];  // Number of elements
+    uint32_t nb[HTP_MAX_DIMS];  // Stride in bytes (see ggml.h ggml_tensor)
+};
+
+#define HTP_MAX_OP_PARAMS 64
+
+struct htp_general_req {
+    uint32_t op;  // GGML/HTP Op
+    int32_t  op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
+    // Params for the op, e.g. epsilon of RMS norm
+    uint32_t flags;          // Request flags
+
+    struct htp_tensor src0;  // Input0 tensor
+    struct htp_tensor src1;  // Input1 tensor
+    struct htp_tensor src2;  // Input2 tensor
+    struct htp_tensor dst;   // Output tensor
+
+    // should be multiple of 64 bytes (cacheline)
+};
+
+struct htp_general_rsp {
+    uint32_t op;           // GGML/HTP Op
+    uint32_t status;       // HTP_STATUS_...
+    uint32_t prof_usecs;   // Number of usec per request
+    uint32_t prof_cycles;  // Number of cycles per request
+    uint32_t prof_pkts;    // Number of instruction packets per request
+    uint8_t  unused[44];   // Pad to 64 bytes
+};
+
+#define HTP_MAX_MESSAGE_SIZE   sizeof(struct htp_general_req)
+#define HTP_MAX_PACKET_BUFFERS 4
+
+#endif /* HTP_MSG_H */
diff --git a/src/ggml-hexagon/htp/htp-ops.h b/src/ggml-hexagon/htp/htp-ops.h
new file mode 100644 (file)
index 0000000..4572319
--- /dev/null
@@ -0,0 +1,53 @@
+#ifndef HTP_OPS_H
+#define HTP_OPS_H
+
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "worker-pool.h"
+
+#include <assert.h>
+#include <stdint.h>
+
+// ggml-common.h must be included prior to this header
+
+struct htp_spad {
+    uint8_t * data;
+    size_t    size;
+    size_t    size_per_thread;
+};
+
+struct htp_ops_context {
+    struct htp_context * ctx;
+
+    enum htp_op op;
+    int32_t     op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
+
+    struct htp_tensor src0;
+    struct htp_tensor src1;
+    struct htp_tensor src2;
+    struct htp_tensor dst;
+
+    struct htp_spad src0_spad;
+    struct htp_spad src1_spad;
+    struct htp_spad src2_spad;
+    struct htp_spad dst_spad;
+
+    worker_pool_context_t * wpool;      // worker pool
+    uint32_t                n_threads;  // num threads
+
+    uint32_t src0_nrows_per_thread;
+    uint32_t src1_nrows_per_thread;
+
+    uint32_t flags;
+};
+
+int op_matmul(struct htp_ops_context * octx);
+int op_matmul_id(struct htp_ops_context * octx);
+int op_binary(struct htp_ops_context * octx);
+int op_unary(struct htp_ops_context * octx);
+int op_activations(struct htp_ops_context * octx);
+int op_softmax(struct htp_ops_context * octx);
+int op_add_id(struct htp_ops_context * octx);
+int op_rope(struct htp_ops_context * octx);
+
+#endif /* HTP_OPS_H */
diff --git a/src/ggml-hexagon/htp/htp_iface.idl b/src/ggml-hexagon/htp/htp_iface.idl
new file mode 100644 (file)
index 0000000..9ebd937
--- /dev/null
@@ -0,0 +1,16 @@
+// FastRPC IDL interface for GGML HTP
+
+#ifndef HTP_IDL
+#define HTP_IDL
+
+#include "AEEStdDef.idl"
+#include "remote.idl"
+
+interface htp_iface : remote_handle64 {
+    AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx);
+    AEEResult stop();
+    AEEResult enable_etm();
+    AEEResult disable_etm();
+};
+
+#endif /* HTP_IDL */
diff --git a/src/ggml-hexagon/htp/hvx-exp.c b/src/ggml-hexagon/htp/hvx-exp.c
new file mode 100644 (file)
index 0000000..19f6795
--- /dev/null
@@ -0,0 +1,80 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
+    int left_over       = num_elems & (VLEN_FP32 - 1);
+    int num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
+        FARF(HIGH, "hvx_exp_f32: unaligned address in hvx op, possibly slower execution\n");
+        unaligned_addr = 1;
+    }
+    // assert((0 == unaligned_addr) || (0 == num_elems_whole));
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+        FARF(HIGH, "hvx_exp_f32: unaligned loop in hvx op, possibly slower execution\n");
+    }
+
+    HVX_Vector vec_out = Q6_V_vzero();
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
+        HVX_Vector * p_vec_out = (HVX_Vector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            if (true == negate) {
+                HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
+                *p_vec_out++          = hvx_vec_exp_fp32(neg_vec_in);
+            } else {
+                *p_vec_out++ = hvx_vec_exp_fp32(*p_vec_in1++);
+            }
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+
+            if (true == negate) {
+                HVX_Vector neg_vec_in                    = hvx_vec_neg_fp32(in);
+                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(neg_vec_in);
+            } else {
+                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(in);
+            }
+        }
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (float *) src + num_elems_whole;
+        float *       dstf = (float *) dst + num_elems_whole;
+
+        HVX_Vector in = *(HVX_UVector *) srcf;
+
+        if (true == negate) {
+            HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
+
+            vec_out = hvx_vec_exp_fp32(neg_vec_in);
+        } else {
+            vec_out = hvx_vec_exp_fp32(in);
+        }
+
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
+    }
+}
diff --git a/src/ggml-hexagon/htp/hvx-inverse.c b/src/ggml-hexagon/htp/hvx-inverse.c
new file mode 100644 (file)
index 0000000..4cf588a
--- /dev/null
@@ -0,0 +1,60 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
+    int left_over       = num_elems & (VLEN_FP32 - 1);
+    int num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
+        FARF(HIGH, "hvx_inverse_f32: unaligned address in hvx op, possibly slower execution\n");
+        unaligned_addr = 1;
+    }
+    // assert((0 == unaligned_addr) || (0 == num_elems_whole));
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+        FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n");
+    }
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * p_vec_in  = (HVX_Vector *) src;
+        HVX_Vector * p_vec_out = (HVX_Vector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            *p_vec_out++ = hvx_vec_inverse_fp32(*p_vec_in++);
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in                            = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32(in);
+        }
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (float *) src + num_elems_whole;
+        float *       dstf = (float *) dst + num_elems_whole;
+
+        HVX_Vector in  = *(HVX_UVector *) srcf;
+        HVX_Vector out = hvx_vec_inverse_fp32(in);
+
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
+    }
+}
diff --git a/src/ggml-hexagon/htp/hvx-sigmoid.c b/src/ggml-hexagon/htp/hvx-sigmoid.c
new file mode 100644 (file)
index 0000000..15ac646
--- /dev/null
@@ -0,0 +1,49 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+#if 0
+// Reference algo used in hvx-utils
+static void fast_sigmoid_f32(const float*  restrict src, float* restrict dst, const int num_elems)
+{
+    const float c1 = 0.03138777;
+    const float c2 = 0.276281267;
+    const float c_log2f = 1.442695022;
+
+    int32_t store_ints[32];
+    float store_floats[3][32];
+
+    for (int i = 0; i < num_elems; i++)
+    {
+        float v = src0[i];
+
+        v *= c_log2f*0.5;
+        int intPart = (int)v;
+        float x = (v - intPart);
+        float xx = x * x;
+        float v1 = c_log2f + c2 * xx;
+        float v2 = x + xx * c1 * x;
+        float v3 = (v2 + v1);
+        *((int*)&v3) += intPart << 24;
+        float v4 = v2 - v1;
+        float v5 = v3 - v4;
+        float res = v3 / v5;
+
+        dst[i] = res;
+    }
+}
+#endif
diff --git a/src/ggml-hexagon/htp/hvx-utils.c b/src/ggml-hexagon/htp/hvx-utils.c
new file mode 100644 (file)
index 0000000..d3599bc
--- /dev/null
@@ -0,0 +1,947 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+#    define FARF_HIGH 1
+#endif
+
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <HAP_ps.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "hvx-utils.h"
+
+#define htp_binary_ops_preamble                                                                                \
+    int step_of_4 = num_elems >> 7;                                                                            \
+    int step_of_2 = (num_elems - step_of_4 * VLEN_FP32 * 4) >> 6;                                              \
+    int step_of_1 = (num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2) >> 5;                  \
+    int remaining = num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32; \
+                                                                                                               \
+    const uint8_t * restrict src0_curr = src0;                                                                 \
+    const uint8_t * restrict src1_curr = src1;                                                                 \
+    uint8_t * restrict dst_curr        = dst;
+
+void hvx_mul_f32(const uint8_t * restrict src0,
+                 const uint8_t * restrict src1,
+                 uint8_t * restrict dst,
+                 const int num_elems) {
+    int left_over       = num_elems & (VLEN_FP32 - 1);
+    int num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
+        (0 == htp_is_aligned((void *) dst, VLEN))) {
+        FARF(HIGH, "hvx_mul_f32: unaligned address in hvx op, possibly slower execution\n");
+        unaligned_addr = 1;
+    }
+
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+        FARF(HIGH, "hvx_mul_f32: unaligned loop in hvx op, possibly slower execution\n");
+    }
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
+        HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
+        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, *vec_in2++);
+            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
+            HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
+
+            HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2);
+
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+        }
+    }
+
+    if (left_over > 0) {
+        const float * src0f = (const float *) src0 + num_elems_whole;
+        const float * src1f = (const float *) src1 + num_elems_whole;
+        float *       dstf  = (float *) dst + num_elems_whole;
+
+        HVX_Vector in1 = *(HVX_UVector *) src0f;
+        HVX_Vector in2 = *(HVX_UVector *) src1f;
+
+        HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2);
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
+    }
+}
+
+void hvx_mul_f32_opt(const uint8_t * restrict src0,
+                     const uint8_t * restrict src1,
+                     uint8_t * restrict dst,
+                     const int num_elems) {
+    htp_binary_ops_preamble;
+
+    for (int i = 0; i < step_of_4; i++) {
+        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
+
+        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
+
+        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
+
+        HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
+
+        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
+
+        HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
+
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
+
+        HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
+
+        HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
+
+        src0_curr += 4 * VLEN;
+
+        HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(v3a, v3b);
+
+        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
+
+        HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
+
+        *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
+
+        HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v4a, v4b);
+
+        src1_curr += 4 * VLEN;
+
+        *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
+
+        dst_curr += 4 * VLEN;
+    }
+
+    for (int i = 0; i < step_of_2; i++) {
+        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
+
+        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
+
+        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
+
+        HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
+
+        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
+
+        src0_curr += 2 * VLEN;
+
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
+
+        src1_curr += 2 * VLEN;
+
+        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
+
+        dst_curr += 2 * VLEN;
+    }
+
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector va = *(HVX_Vector *) src0_curr;
+
+        src0_curr += VLEN;
+
+        HVX_Vector vb = *(HVX_Vector *) src1_curr;
+
+        src1_curr += VLEN;
+
+        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(va, vb);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
+
+        dst_curr += VLEN;
+    }
+
+    if (remaining > 0) {
+        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
+        hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
+    }
+}
+
+void hvx_mul_mul_f32_opt(const uint8_t * restrict src0,
+                         const uint8_t * restrict src1,
+                         const uint8_t * restrict src2,
+                         uint8_t * restrict dst,
+                         const int num_elems) {
+    const uint8_t * restrict src0_curr = src0;
+    const uint8_t * restrict src1_curr = src1;
+    const uint8_t * restrict src2_curr = src2;
+    uint8_t * restrict dst_curr        = dst;
+
+    int step_of_2 = num_elems >> 6;
+    int step_of_1 = (num_elems - step_of_2 * VLEN_FP32 * 2) >> 5;
+    int remaining = num_elems - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32;
+
+    for (int i = 0; i < step_of_2; i++) {
+        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
+        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
+        HVX_Vector v1c = *(HVX_Vector *) src2_curr;
+
+        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
+
+        HVX_Vector v1_ = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
+        HVX_Vector v1  = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1_), v1c);
+
+        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
+
+        HVX_Vector v2c = *(HVX_Vector *) (src2_curr + VLEN);
+
+        src0_curr += 2 * VLEN;
+
+        HVX_Vector v2_ = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
+        HVX_Vector v2  = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2_), v2c);
+
+        src1_curr += 2 * VLEN;
+        src2_curr += 2 * VLEN;
+
+        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
+
+        dst_curr += 2 * VLEN;
+    }
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector va = *(HVX_Vector *) src0_curr;
+        src0_curr += VLEN;
+
+        HVX_Vector vb = *(HVX_Vector *) src1_curr;
+        src1_curr += VLEN;
+
+        HVX_Vector vc = *(HVX_Vector *) src2_curr;
+        src2_curr += VLEN;
+
+        HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(va, vb);
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), vc);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v2);
+        dst_curr += VLEN;
+    }
+    if (remaining > 0) {
+        HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), *(HVX_Vector *) src2_curr);
+        hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v2));
+    }
+}
+
+void hvx_add_f32(const uint8_t * restrict src0,
+                 const uint8_t * restrict src1,
+                 uint8_t * restrict dst,
+                 const int num_elems) {
+    int left_over       = num_elems & (VLEN_FP32 - 1);
+    int num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
+        (0 == htp_is_aligned((void *) dst, VLEN))) {
+        FARF(HIGH, "hvx_add_f32: unaligned address in hvx op, possibly slower execution\n");
+        unaligned_addr = 1;
+    }
+
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+        FARF(HIGH, "hvx_add_f32: unaligned loop in hvx op, possibly slower execution\n");
+    }
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
+        HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
+        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, *vec_in2++);
+            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
+            HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
+
+            HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2);
+
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+        }
+    }
+
+    if (left_over > 0) {
+        const float * src0f = (const float *) src0 + num_elems_whole;
+        const float * src1f = (const float *) src1 + num_elems_whole;
+        float *       dstf  = (float *) dst + num_elems_whole;
+
+        HVX_Vector in1 = *(HVX_UVector *) src0f;
+        HVX_Vector in2 = *(HVX_UVector *) src1f;
+
+        HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2);
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
+    }
+}
+
+void hvx_add_f32_opt(const uint8_t * restrict src0,
+                     const uint8_t * restrict src1,
+                     uint8_t * restrict dst,
+                     const int num_elems) {
+    htp_binary_ops_preamble;
+
+    for (int i = 0; i < step_of_4; i++) {
+        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
+
+        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
+
+        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
+
+        HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b);
+
+        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
+
+        HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
+
+        HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
+
+        HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
+
+        HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
+
+        src0_curr += 4 * VLEN;
+
+        HVX_Vector v3 = Q6_Vqf32_vadd_VsfVsf(v3a, v3b);
+
+        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
+
+        HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
+
+        *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
+
+        HVX_Vector v4 = Q6_Vqf32_vadd_VsfVsf(v4a, v4b);
+
+        src1_curr += 4 * VLEN;
+
+        *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
+
+        dst_curr += 4 * VLEN;
+    }
+    for (int i = 0; i < step_of_2; i++) {
+        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
+
+        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
+
+        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
+
+        HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b);
+
+        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
+
+        src0_curr += 2 * VLEN;
+
+        HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b);
+
+        src1_curr += 2 * VLEN;
+
+        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
+
+        dst_curr += 2 * VLEN;
+    }
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector va = *(HVX_Vector *) src0_curr;
+
+        src0_curr += VLEN;
+
+        HVX_Vector vb = *(HVX_Vector *) src1_curr;
+
+        src1_curr += VLEN;
+
+        HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(va, vb);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
+
+        dst_curr += VLEN;
+    }
+    if (remaining > 0) {
+        HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
+        hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
+    }
+}
+
+void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
+    size_t left_over       = num_elems & (VLEN_FP32 - 1);
+    size_t num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
+        FARF(HIGH, "hvx_add_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
+        unaligned_addr = 1;
+    }
+
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+        FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
+    }
+
+    HVX_Vector val_vec = hvx_vec_splat_fp32(val);
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
+        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, val_vec);
+            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+
+            HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
+
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+        }
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (const float *) src + num_elems_whole;
+        float *       dstf = (float *) dst + num_elems_whole;
+
+        HVX_Vector in = *(HVX_UVector *) srcf;
+
+        HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
+    }
+}
+
+void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
+    size_t left_over       = num_elems & (VLEN_FP32 - 1);
+    size_t num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
+        FARF(HIGH, "hvx_mul_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
+        unaligned_addr = 1;
+    }
+
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+        FARF(HIGH, "hvx_mul_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
+    }
+
+    HVX_Vector val_vec = hvx_vec_splat_fp32(val);
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
+        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, val_vec);
+            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+
+            HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec);
+
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+        }
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (const float *) src + num_elems_whole;
+        float *       dstf = (float *) dst + num_elems_whole;
+
+        HVX_Vector in = *(HVX_UVector *) srcf;
+
+        HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec);
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
+    }
+}
+
+void hvx_sub_f32(const uint8_t * restrict src0,
+                 const uint8_t * restrict src1,
+                 uint8_t * restrict dst,
+                 const int num_elems) {
+    size_t left_over       = num_elems & (VLEN_FP32 - 1);
+    size_t num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
+        (0 == htp_is_aligned((void *) dst, VLEN))) {
+        FARF(HIGH, "hvx_sub_f32: unaligned address in hvx op, possibly slower execution\n");
+        unaligned_addr = 1;
+    }
+
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+        FARF(HIGH, "hvx_sub_f32: unaligned loop in hvx op, possibly slower execution\n");
+    }
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
+        HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
+        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, *vec_in2++);
+            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
+            HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
+
+            HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2);
+
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+        }
+    }
+
+    if (left_over > 0) {
+        const float * src0f = (const float *) src0 + num_elems_whole;
+        const float * src1f = (const float *) src1 + num_elems_whole;
+        float *       dstf  = (float *) dst + num_elems_whole;
+
+        HVX_Vector in1 = *(HVX_UVector *) src0f;
+        HVX_Vector in2 = *(HVX_UVector *) src1f;
+
+        HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2);
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
+    }
+}
+
+void hvx_sub_f32_opt(const uint8_t * restrict src0,
+                     const uint8_t * restrict src1,
+                     uint8_t * restrict dst,
+                     const int num_elems) {
+    htp_binary_ops_preamble;
+
+    for (int i = 0; i < step_of_4; i++) {
+        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
+
+        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
+
+        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
+
+        HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b);
+
+        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
+
+        HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
+
+        HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
+
+        HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
+
+        HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
+
+        src0_curr += 4 * VLEN;
+
+        HVX_Vector v3 = Q6_Vqf32_vsub_VsfVsf(v3a, v3b);
+
+        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
+
+        HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
+
+        *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
+
+        HVX_Vector v4 = Q6_Vqf32_vsub_VsfVsf(v4a, v4b);
+
+        src1_curr += 4 * VLEN;
+
+        *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
+
+        dst_curr += 4 * VLEN;
+    }
+    for (int i = 0; i < step_of_2; i++) {
+        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
+
+        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
+
+        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
+
+        HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b);
+
+        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
+
+        src0_curr += 2 * VLEN;
+
+        HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b);
+
+        src1_curr += 2 * VLEN;
+
+        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
+
+        dst_curr += 2 * VLEN;
+    }
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector va = *(HVX_Vector *) src0_curr;
+
+        src0_curr += VLEN;
+
+        HVX_Vector vb = *(HVX_Vector *) src1_curr;
+
+        src1_curr += VLEN;
+
+        HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(va, vb);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
+
+        dst_curr += VLEN;
+    }
+    if (remaining > 0) {
+        HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
+        hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
+    }
+}
+
+void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
+    size_t left_over       = num_elems & (VLEN_FP32 - 1);
+    size_t num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
+        FARF(HIGH, "hvx_sub_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
+        unaligned_addr = 1;
+    }
+
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+        FARF(HIGH, "hvx_sub_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
+    }
+
+    HVX_Vector val_vec = hvx_vec_splat_fp32(val);
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
+        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, val_vec);
+            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+
+            HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec);
+
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+        }
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (const float *) src + num_elems_whole;
+        float *       dstf = (float *) dst + num_elems_whole;
+
+        HVX_Vector in = *(HVX_UVector *) srcf;
+
+        HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec);
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
+    }
+}
+
+float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) {
+    int left_over       = num_elems & (VLEN_FP32 - 1);
+    int num_elems_whole = num_elems - left_over;
+
+    if (0 == htp_is_aligned((void *) src, VLEN)) {
+        FARF(HIGH, "hvx_sum_of_squares_f32: unaligned address in hvx op, possibly slower execution\n");
+    }
+
+    assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
+
+    HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
+
+    HVX_Vector sum_vec_acc = Q6_V_vsplat_R(0x00000000);
+    HVX_Vector zero_vec    = Q6_V_vsplat_R(0x00000000);
+
+    #pragma unroll(4)
+    for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1, *vec_in1);
+        sum_vec_acc  = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, v);
+        vec_in1++;
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (const float *) src + num_elems_whole;
+
+        HVX_Vector vec_left = *(HVX_UVector *) srcf;
+
+        HVX_Vector vec_left_sq = Q6_Vqf32_vmpy_VsfVsf(vec_left, vec_left);
+        HVX_Vector vec_tmp     = Q6_V_valign_VVR(vec_left_sq, zero_vec, left_over * SIZEOF_FP32);
+
+        sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, vec_tmp);
+    }
+
+    HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec_acc);
+    return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
+}
+
+float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) {
+    int left_over       = num_elems & (VLEN_FP32 - 1);
+    int num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if (0 == htp_is_aligned((void *) src, VLEN)) {
+        FARF(HIGH, "hvx_self_sum_f32: unaligned address in hvx op, possibly slower execution\n");
+        unaligned_addr = 1;
+    }
+
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+        FARF(HIGH, "hvx_self_sum_f32: unaligned loop in hvx op, possibly slower execution\n");
+    }
+
+    HVX_Vector sum_vec  = Q6_V_vsplat_R(0x00000000);
+    HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000);
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * vec_in = (HVX_Vector *) src;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            // sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, *vec_in++);
+            sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), *vec_in++);
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+
+            sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), in);
+        }
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (const float *) src + num_elems_whole;
+
+        HVX_Vector vec_left = *(HVX_UVector *) srcf;
+        HVX_Vector vec_tmp  = Q6_V_valign_VVR(vec_left, zero_vec, left_over * SIZEOF_FP32);
+        // sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, vec_tmp);
+        sum_vec             = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), vec_tmp);
+    }
+
+    HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec);
+    return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
+}
+
+void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) {
+    int left_over       = num_elems & (VLEN_FP32 - 1);
+    int num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
+        FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n");
+        unaligned_addr = 1;
+    }
+
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+        FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n");
+    }
+
+    HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * vec_in1 = (HVX_Vector *) src;
+        HVX_Vector * vec_out = (HVX_Vector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec);
+            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+
+            HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
+
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+        }
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (const float *) src + num_elems_whole;
+        float *       dstf = (float *) dst + num_elems_whole;
+
+        HVX_Vector in = *(HVX_UVector *) srcf;
+
+        HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
+    }
+}
+
+float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
+    int left_over       = num_elems & (VLEN_FP32 - 1);
+    int num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if (0 == htp_is_aligned((void *) src, VLEN)) {
+        FARF(HIGH, "hvx_self_max_f32: unaligned address in hvx op, possibly slower execution\n");
+        unaligned_addr = 1;
+    }
+
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+        FARF(HIGH, "hvx_self_max_f32: unaligned loop in hvx op, possibly slower execution\n");
+    }
+
+    HVX_Vector vec_max   = hvx_vec_splat_fp32(((const float *) src)[0]);
+    HVX_Vector vec_first = hvx_vec_splat_fp32(((const float *) src)[0]);
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * restrict vec_in = (HVX_Vector *) src;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, *vec_in++);
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+
+            vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, in);
+        }
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (const float *) src + num_elems_whole;
+
+        HVX_Vector in = *(HVX_UVector *) srcf;
+
+        HVX_Vector temp = Q6_V_valign_VVR(in, vec_first, left_over * SIZEOF_FP32);
+        vec_max         = Q6_Vsf_vmax_VsfVsf(vec_max, temp);
+    }
+
+    HVX_Vector v = hvx_vec_reduce_max_fp32(vec_max);
+    return hvx_vec_get_fp32(v);
+}
+
+void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
+    size_t left_over       = num_elems & (VLEN_FP32 - 1);
+    size_t num_elems_whole = num_elems - left_over;
+
+    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
+        FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
+    }
+
+    assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
+
+    const float * src_f = (const float *) src;
+
+    HVX_Vector vec_min = Q6_V_vsplat_R(val);
+
+    HVX_Vector * restrict vec_in  = (HVX_Vector *) src;
+    HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
+
+    #pragma unroll(4)
+    for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+        vec_min    = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
+        *vec_out++ = Q6_Vsf_equals_Vqf32(vec_min);
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (const float *) src + num_elems_whole;
+        float *       dstf = (float *) dst + num_elems_whole;
+
+        HVX_Vector in = *(HVX_UVector *) srcf;
+
+        vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, in);
+
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(vec_min));
+    }
+}
+
+void hvx_clamp_scalar_f32(const uint8_t * restrict src,
+                          const float limit_left,
+                          const float limit_right,
+                          uint8_t * restrict dst,
+                          const int num_elems) {
+    size_t left_over       = num_elems & (VLEN_FP32 - 1);
+    size_t num_elems_whole = num_elems - left_over;
+
+    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
+        FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
+    }
+
+    assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
+
+    HVX_Vector * restrict vec_in  = (HVX_Vector *) src;
+    HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
+
+    HVX_Vector range_left  = hvx_vec_splat_fp32(limit_left);
+    HVX_Vector range_right = hvx_vec_splat_fp32(limit_right);
+
+    #pragma unroll(4)
+    for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+        HVX_Vector in_vec = *vec_in++;
+        HVX_Vector temp_v = in_vec;
+
+        HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
+        HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
+
+        in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
+        in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
+
+        *vec_out++ = Q6_Vsf_equals_Vqf32(in_vec);
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (const float *) src + num_elems_whole;
+        float *       dstf = (float *) dst + num_elems_whole;
+
+        HVX_Vector in = *(HVX_UVector *) srcf;
+
+        HVX_Vector temp_v = in;
+
+        HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in, range_right);
+        HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(range_left, in);
+
+        in = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
+        in = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
+
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(in));
+    }
+}
diff --git a/src/ggml-hexagon/htp/hvx-utils.h b/src/ggml-hexagon/htp/hvx-utils.h
new file mode 100644 (file)
index 0000000..b2ca8e8
--- /dev/null
@@ -0,0 +1,998 @@
+#ifndef HVX_UTILS_H
+#define HVX_UTILS_H
+
+#include "ops-utils.h"
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#define SIZEOF_FP32 (4)
+#define SIZEOF_FP16 (2)
+#define VLEN        (128)
+#define VLEN_FP32   (VLEN / SIZEOF_FP32)
+#define VLEN_FP16   (VLEN / SIZEOF_FP16)
+
+static inline HVX_Vector hvx_vec_splat_fp32(float i) {
+    union {
+        float   f;
+        int32_t i;
+    } fp32 = { .f = i };
+
+    return Q6_V_vsplat_R(fp32.i);
+}
+
+static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) {
+    // Rotate as needed.
+    v = Q6_V_vlalign_VVR(v, v, (size_t) addr);
+
+    uint32_t left_off  = (size_t) addr & 127;
+    uint32_t right_off = left_off + n;
+
+    HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) addr);
+    HVX_VectorPred qr     = Q6_Q_vsetq2_R(right_off);
+
+    if (right_off > 128) {
+        Q6_vmem_QRIV(qr, (HVX_Vector *) addr + 1, v);
+        // all 1's
+        qr = Q6_Q_vcmp_eq_VbVb(v, v);
+    }
+
+    ql_not = Q6_Q_or_QQn(ql_not, qr);
+    Q6_vmem_QnRIV(ql_not, (HVX_Vector *) addr, v);
+}
+
+static inline void hvx_vec_store_a(void * ptr, size_t n, HVX_Vector v) {
+    assert((unsigned long) ptr % 128 == 0);
+
+    HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) ptr);
+    HVX_VectorPred qr     = Q6_Q_vsetq2_R(n);
+    ql_not                = Q6_Q_or_QQn(ql_not, qr);
+    Q6_vmem_QnRIV(ql_not, (HVX_Vector *) ptr, v);
+}
+
+static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) {
+    // vdelta control to replicate first 4 bytes across all elements
+    static const uint8_t __attribute__((aligned(128))) repl[128] = {
+        0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+    };
+
+    HVX_Vector ctrl = *(HVX_Vector *) repl;
+    return Q6_V_vdelta_VV(v, ctrl);
+}
+
+// copy n fp16 elements : source and destination are aligned to HVX Vector (128)
+static inline void hvx_copy_fp16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    HVX_Vector * restrict vdst = (HVX_Vector *) dst;
+    HVX_Vector * restrict vsrc = (HVX_Vector *) src;
+
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+
+    uint32_t nvec = n / 64;
+    uint32_t nloe = n % 64;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (; i < nvec; i++) {
+        HVX_Vector v = vsrc[i];
+        vdst[i]      = v;
+    }
+
+    if (nloe) {
+        HVX_Vector v = vsrc[i];
+        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
+    }
+}
+
+// copy n fp16 elements : source is aligned, destination is potentially unaligned
+static inline void hvx_copy_fp16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    HVX_UVector * restrict vdst = (HVX_UVector *) dst;
+    HVX_Vector * restrict vsrc  = (HVX_Vector *) src;
+
+    assert((unsigned long) src % 128 == 0);
+
+    uint32_t nvec = n / 64;
+    uint32_t nloe = n % 64;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (; i < nvec; i++) {
+        HVX_Vector v = vsrc[i];
+        vdst[i]      = v;
+    }
+
+    if (nloe) {
+        HVX_Vector v = vsrc[i];
+        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
+    }
+}
+
+// copy n fp16 elements : source is aligned, destination is potentially unaligned
+static inline void hvx_copy_fp16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    HVX_Vector * restrict vdst  = (HVX_Vector *) dst;
+    HVX_UVector * restrict vsrc = (HVX_UVector *) src;
+
+    assert((unsigned long) dst % 128 == 0);
+
+    uint32_t nvec = n / 64;
+    uint32_t nloe = n % 64;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (; i < nvec; i++) {
+        HVX_Vector v = vsrc[i];
+        vdst[i]      = v;
+    }
+
+    if (nloe) {
+        HVX_Vector v = vsrc[i];
+        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
+    }
+}
+
+// copy n fp32 elements : source and destination are aligned to HVX Vector (128)
+static inline void hvx_copy_fp32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    HVX_Vector * restrict vdst = (HVX_Vector *) dst;
+    HVX_Vector * restrict vsrc = (HVX_Vector *) src;
+
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+
+    uint32_t nvec = n / 32;
+    uint32_t nloe = n % 32;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (; i < nvec; i++) {
+        HVX_Vector v = vsrc[i];
+        vdst[i]      = v;
+    }
+
+    if (nloe) {
+        HVX_Vector v = vsrc[i];
+        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
+    }
+}
+
+// copy n fp32 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    HVX_UVector * restrict vdst = (HVX_UVector *) dst;
+    HVX_Vector * restrict vsrc  = (HVX_Vector *) src;
+
+    assert((unsigned long) src % 128 == 0);
+
+    uint32_t nvec = n / 32;
+    uint32_t nloe = n % 32;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (; i < nvec; i++) {
+        HVX_Vector v = vsrc[i];
+        vdst[i]      = v;
+    }
+
+    if (nloe) {
+        HVX_Vector v = vsrc[i];
+        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
+    }
+}
+
+// copy n fp32 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    HVX_Vector * restrict vdst  = (HVX_Vector *) dst;
+    HVX_UVector * restrict vsrc = (HVX_UVector *) src;
+
+    assert((unsigned long) dst % 128 == 0);
+
+    uint32_t nvec = n / 32;
+    uint32_t nloe = n % 32;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (; i < nvec; i++) {
+        HVX_Vector v = vsrc[i];
+        vdst[i]      = v;
+    }
+
+    if (nloe) {
+        HVX_Vector v = vsrc[i];
+        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
+    }
+}
+
+// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned
+static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) {
+    HVX_Vector * restrict vdst = (HVX_Vector *) dst;
+
+    HVX_Vector velem = hvx_vec_splat_fp32(elem);
+
+    assert((unsigned long) dst % 128 == 0);
+
+    uint32_t nvec = n / 32;
+    uint32_t nloe = n % 32;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (; i < nvec; i++) {
+        vdst[i] = velem;
+    }
+
+    if (nloe) {
+        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), velem);
+    }
+}
+
+static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
+    uint32_t left_off  = (size_t) addr & (chunk_size - 1);
+    uint32_t right_off = left_off + n;
+    return right_off <= chunk_size;
+}
+
+static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
+    union {
+        HVX_Vector v;
+        __fp16 d[64];
+    } u = { .v = v };
+
+    const uint32_t n0 = n / 16;
+    const uint32_t n1 = n % 16;
+    int            i  = 0;
+    for (; i < n0; i++) {
+        htp_dump_fp16_line(pref, u.d + (16 * i), 16);
+    }
+    if (n1) {
+        htp_dump_fp16_line(pref, u.d + (16 * i), n1);
+    }
+}
+
+static void hvx_vec_dump_fp16(char * pref, HVX_Vector v) {
+    hvx_vec_dump_fp16_n(pref, v, 64);
+}
+
+static void hvx_vec_dump_fp32_n(char * pref, HVX_Vector v, uint32_t n) {
+    union {
+        HVX_Vector v;
+        float      d[32];
+    } u = { .v = v };
+
+    const uint32_t n0 = n / 16;
+    const uint32_t n1 = n % 16;
+    int            i  = 0;
+    for (; i < n0; i++) {
+        htp_dump_fp32_line(pref, u.d + (16 * i), 16);
+    }
+    if (n1) {
+        htp_dump_fp32_line(pref, u.d + (16 * i), n1);
+    }
+}
+
+static void hvx_vec_dump_fp32_hmt(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        float      d[32];
+    } u = { .v = v };
+
+    FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ...  %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1],
+         u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
+}
+
+static void hvx_vec_dump_fp32(char * pref, HVX_Vector v) {
+    hvx_vec_dump_fp32_n(pref, v, 32);
+}
+
+static void hvx_vec_dump_int32(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        int32_t    d[32];
+    } u = { .v = v };
+
+    for (int i = 0; i < 32 / 16; i++) {
+        htp_dump_int32_line(pref, u.d + (16 * i), 16);
+    }
+}
+
+static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        int32_t    d[32];
+    } u = { .v = v };
+
+    FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12],
+         u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
+}
+
+static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        int8_t     d[128];
+    } u = { .v = v };
+
+    FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60],
+         u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]);
+}
+
+static void hvx_vec_dump_int8(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        int8_t     d[128];
+    } u = { .v = v };
+
+    for (int i = 0; i < 128 / 16; i++) {
+        htp_dump_int8_line(pref, u.d + (16 * i), 16);
+    }
+}
+
+static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        uint8_t    d[128];
+    } u = { .v = v };
+
+    for (int i = 0; i < 128 / 16; i++) {
+        htp_dump_uint8_line(pref, u.d + (16 * i), 16);
+    }
+}
+
+static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) {
+    typedef union {
+        HVX_Vector v;
+        int8_t     d[128];
+    } U;
+
+    U u0 = { .v = v0 };
+    U u1 = { .v = v1 };
+
+    for (int i = 0; i < n; i++) {
+        if (u0.d[i] != u1.d[i]) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static inline float hvx_vec_get_fp32(HVX_Vector v) {
+    float __attribute__((aligned(128))) x;
+    hvx_vec_store_a(&x, 4, v);
+    return x;
+}
+
+static inline HVX_Vector hvx_vec_int32_reduce_sum_n(HVX_Vector in, unsigned int n) {
+    unsigned int total = n * 4;  // total vec nbytes
+    unsigned int width = 4;      // int32
+
+    HVX_Vector sum = in, sum_t;
+    while (width < total) {
+        sum_t = Q6_V_vror_VR(sum, width);     // rotate right
+        sum   = Q6_Vw_vadd_VwVw(sum_t, sum);  // elementwise sum
+        width = width << 1;
+    }
+    return sum;
+}
+
+static inline HVX_Vector hvx_vec_int32_reduce_sum(HVX_Vector in) {
+    return hvx_vec_int32_reduce_sum_n(in, 32);
+}
+
+static inline HVX_Vector hvx_vec_qf32_reduce_sum_n(HVX_Vector in, unsigned int n) {
+    unsigned int total = n * 4;  // total vec nbytes
+    unsigned int width = 4;      // fp32 nbytes
+
+    HVX_Vector sum = in, sum_t;
+    while (width < total) {
+        sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width);  // rotate right
+        sum   = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t);             // elementwise sum
+        width = width << 1;
+    }
+    return sum;
+}
+
+static inline HVX_Vector hvx_vec_qf32_reduce_sum(HVX_Vector in) {
+    return hvx_vec_qf32_reduce_sum_n(in, 32);
+}
+
+static inline HVX_Vector hvx_vec_fp32_reduce_sum_n(HVX_Vector in, unsigned int n) {
+    unsigned int total = n * 4;  // total vec nbytes
+    unsigned int width = 4;      // fp32 nbytes
+
+    HVX_Vector sum = in, sum_t;
+    while (width < total) {
+        sum_t = Q6_V_vror_VR(sum, width);       // rotate right
+        sum   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum
+        width = width << 1;
+    }
+    return sum;
+}
+
+static inline HVX_Vector hvx_vec_fp32_reduce_sum(HVX_Vector in) {
+    return hvx_vec_fp32_reduce_sum_n(in, 32);
+}
+
+static inline HVX_Vector hvx_vec_reduce_max_fp16(HVX_Vector in) {
+    unsigned total = 128;  // total vec nbytes
+    unsigned width = 2;    // fp16 nbytes
+
+    HVX_Vector _max = in, _max_t;
+    while (width < total) {
+        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
+        _max   = Q6_Vhf_vmax_VhfVhf(_max_t, _max);  // elementwise max
+        width  = width << 1;
+    }
+
+    return _max;
+}
+
+static inline HVX_Vector hvx_vec_reduce_max2_fp16(HVX_Vector in, HVX_Vector _max) {
+    unsigned total = 128;  // total vec nbytes
+    unsigned width = 2;    // fp32 nbytes
+
+    HVX_Vector _max_t;
+
+    _max = Q6_Vhf_vmax_VhfVhf(in, _max);
+    while (width < total) {
+        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
+        _max   = Q6_Vhf_vmax_VhfVhf(_max_t, _max);  // elementwise max
+        width  = width << 1;
+    }
+
+    return _max;
+}
+
+static inline HVX_Vector hvx_vec_reduce_max_fp32(HVX_Vector in) {
+    unsigned total = 128;  // total vec nbytes
+    unsigned width = 4;    // fp32 nbytes
+
+    HVX_Vector _max = in, _max_t;
+    while (width < total) {
+        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
+        _max   = Q6_Vsf_vmax_VsfVsf(_max_t, _max);  // elementwise max
+        width  = width << 1;
+    }
+
+    return _max;
+}
+
+static inline HVX_Vector hvx_vec_reduce_max2_fp32(HVX_Vector in, HVX_Vector _max) {
+    unsigned total = 128;  // total vec nbytes
+    unsigned width = 4;    // fp32 nbytes
+
+    HVX_Vector _max_t;
+
+    _max = Q6_Vsf_vmax_VsfVsf(in, _max);
+    while (width < total) {
+        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
+        _max   = Q6_Vsf_vmax_VsfVsf(_max_t, _max);  // elementwise max
+        width  = width << 1;
+    }
+
+    return _max;
+}
+
+static inline HVX_Vector hvx_vec_abs_fp16(HVX_Vector v) {
+    // abs by clearing the fp16 sign bit
+    HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
+    return Q6_V_vand_VV(v, mask);
+}
+
+static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) {
+    // neg by setting the fp16 sign bit
+    HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
+    return Q6_V_vor_VV(v, mask);
+}
+
+static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
+    // abs by clearing the fp32 sign bit
+    HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff);
+    return Q6_V_vand_VV(v, mask);
+}
+
+static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
+#if __HTP_ARCH__ > 75
+    return Q6_Vsf_vfneg_Vsf(v);
+#else
+    // neg by setting the fp32 sign bit
+    HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
+    return Q6_V_vor_VV(v, mask);
+#endif  // __HTP_ARCH__ > 75
+}
+
+// ====================================================
+// FUNCTION: 1/(x+1)     y(0) = 1,  y(0.5) = 0.6667, y(1) = 0.5
+// Order:3; continuity: True; Ends forced: True
+// Mode: unsigned;   Result fractional bits: 14
+// Peak Error: 1.1295e-04  Rms Error: 2.8410e-05   Mean Error: 1.1370e-05
+//      32769  -32706   31252  -10589
+//      32590  -30635   22793   -4493
+//      32066  -27505   16481   -2348
+//      31205  -24054   11849   -1306
+
+static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) {
+    // input is 0..0xffff representing 0.0  .. 1.0
+    HVX_Vector p;
+    p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull);
+    p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull);
+    p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull);
+    p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull);
+    return p;  // signed result, 14 fractional bits
+}
+
+// Find reciprocal of fp16.
+// (1) first, convert to fp32, multiplying by 1.0; this is done to
+//    handle denormals. Ignoring sign and zero, result should be at
+//    least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000)
+//    (exponent in range [103,143])
+// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly
+// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32
+// (4) convert that to fp16
+// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace
+//     the result with the max value.
+static inline HVX_Vector hvx_vec_inverse_fp16(HVX_Vector vals) {
+    HVX_Vector     em_mask  = Q6_Vh_vsplat_R(0x7FFF);
+    HVX_Vector     avals    = Q6_V_vand_VV(vals, em_mask);
+    HVX_VectorPred is_neg   = Q6_Q_vcmp_gt_VhVh(avals, vals);
+    // is too small to 1/x ? for 'standard' fp16, this would be 0x101
+    HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals);
+
+    HVX_VectorPair to_qf32  = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00));  // *1.0
+    HVX_Vector     to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32));
+    HVX_Vector     to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32));
+
+    // bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector
+    HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9));
+    // likewise extract the upper 16 from each, containing the exponents in range 103..142
+    HVX_Vector exp_u16  = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0);
+    //Get exponent in IEEE 32-bit representation
+    exp_u16             = Q6_Vuh_vlsr_VuhR(exp_u16, 7);
+
+    // so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane
+    // We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0)
+    // Use poly to transform to 1/x, with 14 fractional bits
+    //
+    HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16);
+
+    HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm);  //count leading zeros
+
+    // Get mantissa for 16-bit represenation
+    HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF));
+
+    //Compute Reciprocal Exponent
+    HVX_Vector exp_recip =
+        Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1)));
+    //Convert it for 16-bit representation
+    exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15));
+    exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10);
+
+    //Merge exponent and mantissa for reciprocal
+    HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip);
+    // map 'small' inputs to standard largest value 0x7bff
+    recip            = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip);
+    // add sign back
+    recip            = Q6_V_vandor_VQR(recip, is_neg, 0x80008000);
+    return recip;
+}
+
+#define IEEE_VSF_EXPLEN   (8)
+#define IEEE_VSF_EXPBIAS  (127)
+#define IEEE_VSF_EXPMASK  (0xFF)
+#define IEEE_VSF_MANTLEN  (23)
+#define IEEE_VSF_MANTMASK (0x7FFFFF)
+#define IEEE_VSF_MIMPMASK (0x800000)
+
+static inline HVX_Vector hvx_vec_truncate_fp32(HVX_Vector in_vec) {
+    HVX_Vector mask_mant_v  = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
+    HVX_Vector mask_impl_v  = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
+    HVX_Vector const_zero_v = Q6_V_vzero();
+
+    HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
+
+    HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
+    expval_v &= IEEE_VSF_EXPMASK;
+    expval_v -= IEEE_VSF_EXPBIAS;
+
+    // negative exp == fractional value
+    HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
+
+    HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v;         // fractional bits - exp shift
+
+    HVX_Vector mant_v = in_vec & mask_mant_v;                  // obtain mantissa
+    HVX_Vector vout   = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v);  // add implicit 1.0
+
+    vout = Q6_Vw_vasr_VwVw(vout, rshift_v);                    // shift to obtain truncated integer
+    vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout);        // expval<0 -> 0
+
+    HVX_Vector neg_vout = -vout;
+
+    vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout);  // handle negatives
+
+    return (vout);
+}
+
+static inline HVX_Vector hvx_vec_floor_fp32(HVX_Vector in_vec) {
+    HVX_Vector mask_mant_v    = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
+    HVX_Vector mask_impl_v    = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
+    HVX_Vector const_mnlen_v  = Q6_V_vsplat_R(IEEE_VSF_MANTLEN);
+    HVX_Vector const_zero_v   = Q6_V_vzero();
+    HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000);  // -1 IEEE vsf
+
+    HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
+
+    HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
+    expval_v &= IEEE_VSF_EXPMASK;
+    expval_v -= IEEE_VSF_EXPBIAS;
+
+    HVX_VectorPred q_negexp     = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
+    HVX_VectorPred q_expltmn    = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v);
+    HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v);
+    HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec);
+
+    // if expval < 0 (q_negexp)         // <0, floor is 0
+    //    if vin > 0
+    //       floor = 0
+    //    if vin < 0
+    //       floor = -1
+    // if expval < mant_len (q_expltmn) // >0, but fraction may exist
+    //    get sign (q_negative)
+    //    mask >> expval                // fraction bits to mask off
+    //    vout = ~(mask)                // apply mask to remove fraction
+    //    if (qneg)                     // negative floor is one less (more, sign bit for neg)
+    //      vout += ((impl_mask) >> expval)
+    //    if (mask && vin)
+    //      vout = vin
+    // else                             // already an integer
+    //    ;                             // no change
+
+    // compute floor
+    mask_mant_v >>= expval_v;
+    HVX_Vector neg_addin_v    = mask_impl_v >> expval_v;
+    HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v);
+    HVX_Vector vout           = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec);
+
+    HVX_Vector     mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v);  // chk if bits set
+    HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v);
+
+    HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v);        // frac bits to clear
+    HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v);  // clear frac bits
+
+    vout = in_vec;
+    vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout);         // expval<mant
+    vout = Q6_V_vmux_QVV(q_integral, in_vec, vout);            // integral values
+    vout = Q6_V_vmux_QVV(q_negexp_pos, const_zero_v, vout);    // expval<0 x>0 -> 0
+    vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout);  // expval<0 x<0 -> -1
+
+    return vout;
+}
+
+static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
+    // This looks complicated.
+    // Ideally should just be Q6_Vh_equals_Vhf(vin)
+    // but that instruction does not do proper rounding.
+
+    // convert to qf32, multiplying by 1.0 in the process.
+    HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00));
+
+    // 'in-range' values are +/32752.
+    // add 192K to it, convert to sf
+    HVX_Vector v192K = Q6_V_vsplat_R(0x48400000);
+    HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K));
+    HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K));
+
+    // for in-range cases, result is {163858... 229360} so the exponent is always 144.
+    // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer.
+    // Start by <<10 to get the final 'sign' bit in bit 15...
+    vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10);
+    vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10);
+
+    // now round down to 16
+    return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0);
+}
+
+static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) {
+    HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3);
+    HVX_Vector two_sf       = hvx_vec_splat_fp32(2.0);
+
+    // First approximation
+    HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf);
+
+    HVX_Vector r_qf;
+
+    // Refine
+    r_qf = Q6_Vqf32_vmpy_VsfVsf(
+        i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf)))));
+    r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
+        r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
+    r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
+        r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
+
+    return Q6_Vsf_equals_Vqf32(r_qf);
+}
+
+#define FAST_SIGMOID_LOG2F (0x3fb8aa3b)  // 1.442695022
+#define FAST_SIGMOID_C1    (0x3d009076)  // 0.03138777
+#define FAST_SIGMOID_C2    (0x3e8d74bd)  // 0.276281267
+#define FAST_SIGMOID_C3    (0x3f000000)  // 0.5
+
+static inline HVX_Vector hvx_vec_fast_sigmoid_fp32(HVX_Vector v) {
+    v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
+    v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3));
+
+    HVX_Vector in_int = hvx_vec_truncate_fp32(Q6_Vsf_equals_Vqf32(v));
+    HVX_Vector x      = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int));
+    HVX_Vector xx     = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x);
+
+    HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2));
+    v1            = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
+
+    HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1));
+    v2            = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx);
+    v2            = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x);
+
+    HVX_Vector v3          = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1));
+    HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1);
+    v3_exponent            = Q6_Vuw_vlsr_VuwR(v3_exponent, 24);
+    v3_exponent            = Q6_Vw_vadd_VwVw(in_int, v3_exponent);
+    v3                     = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24);
+
+    HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1));
+    HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4));
+
+    HVX_Vector res = hvx_vec_inverse_fp32(v5);
+    res            = Q6_Vqf32_vmpy_VsfVsf(v3, res);
+
+    return Q6_Vsf_equals_Vqf32(res);
+}
+
+#define EXP_COEFF_5 (0x39506967)  // 0.000198757 = 1/(7!)
+#define EXP_COEFF_4 (0x3AB743CE)  // 0.0013982   = 1/(6!)
+#define EXP_COEFF_3 (0x3C088908)  // 0.00833345  = 1/(5!)
+#define EXP_COEFF_2 (0x3D2AA9C1)  // 0.416658    = 1/(4!)
+#define EXP_COEFF_1 (0x3E2AAAAA)  // 0.16666667  = 1/(3!)
+#define EXP_COEFF_0 (0x3F000000)  // 0.5         = 1/(2!)
+#define EXP_LOGN2   (0x3F317218)  // ln(2)   = 0.6931471805
+#define EXP_LOG2E   (0x3FB8AA3B)  // log2(e) = 1/ln(2) = 1.4426950408
+#define EXP_ONE     (0x3f800000)  // 1.0
+#define EXP_RANGE_R (0x41a00000)  // 20.0
+#define EXP_RANGE_L (0xc1a00000)  // -20.0
+
+static inline HVX_Vector hvx_vec_exp_fp32(HVX_Vector in_vec) {
+    HVX_Vector z_qf32_v;
+    HVX_Vector x_v;
+    HVX_Vector x_qf32_v;
+    HVX_Vector y_v;
+    HVX_Vector k_v;
+    HVX_Vector f_v;
+    HVX_Vector epsilon_v;
+    HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E);
+    HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2);
+    HVX_Vector E_const;
+    HVX_Vector zero_v = Q6_V_vzero();
+
+    // exp(x) is approximated as follows:
+    //   f = floor(x/ln(2)) = floor(x*log2(e))
+    //   epsilon = x - f*ln(2)
+    //   exp(x) = exp(epsilon+f*ln(2))
+    //          = exp(epsilon)*exp(f*ln(2))
+    //          = exp(epsilon)*2^f
+    //
+    //   Since epsilon is close to zero, it can be approximated with its Taylor series:
+    //            exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+...
+    //   Preserving the first eight elements, we get:
+    //            exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7
+    //                   =  1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2
+
+    HVX_Vector temp_v = in_vec;
+
+    // Clamp inputs to (-20.0, 20.0)
+    HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
+    HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
+
+    in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
+    in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v);
+
+    epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
+    epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
+
+    //    f_v is the floating point result and k_v is the integer result
+    f_v = hvx_vec_floor_fp32(epsilon_v);
+    k_v = hvx_vec_truncate_fp32(f_v);
+
+    x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v);
+
+    //  x = x - f_v * logn2;
+    epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2);
+    x_qf32_v  = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v);
+    // normalize before every QFloat's vmpy
+    x_qf32_v  = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
+
+    // z = x * x;
+    z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
+    z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
+
+    x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
+
+    // y = E4 + E5 * x;
+    E_const = Q6_V_vsplat_R(EXP_COEFF_5);
+    y_v     = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
+    E_const = Q6_V_vsplat_R(EXP_COEFF_4);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = E3 + y * x;
+    E_const = Q6_V_vsplat_R(EXP_COEFF_3);
+    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = E2 + y * x;
+    E_const = Q6_V_vsplat_R(EXP_COEFF_2);
+    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = E1 + y * x;
+    E_const = Q6_V_vsplat_R(EXP_COEFF_1);
+    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = E0 + y * x;
+    E_const = Q6_V_vsplat_R(EXP_COEFF_0);
+    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = x + y * z;
+    y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v);
+    y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v);
+    y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = y + 1.0;
+    y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE));
+
+    // insert exponents
+    //        y = ldexpf(y, k);
+    //    y_v += k_v; // qf32
+    // modify exponent
+
+    y_v = Q6_Vsf_equals_Vqf32(y_v);
+
+    // add k_v to the exponent of y_v
+    HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1);
+
+    y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1);
+    y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent);
+
+    // exponent cannot be negative; if overflow is detected, result is set to zero
+    HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent);
+
+    y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN);
+
+    y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v);
+
+    return y_v;
+}
+
+#define RSQRT_CONST        0x5f3759df  // Constant for fast inverse square root calculation
+#define RSQRT_ONE_HALF     0x3f000000  // 0.5
+#define RSQRT_THREE_HALVES 0x3fc00000  // 1.5
+
+static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) {
+    //Algorithm :
+    //  x2 = input*0.5
+    //  y  = * (long *) &input
+    //  y  = 0x5f3759df - (y>>2)
+    //  y  = y*(threehalfs - x2*y*y)
+
+    HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
+    HVX_Vector onehalf    = Q6_V_vsplat_R(RSQRT_ONE_HALF);
+    HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES);
+
+    HVX_Vector x2, y, ypower2, temp;
+
+    x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf);
+    x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero());
+
+    y = Q6_Vw_vasr_VwR(in_vec, 1);
+    y = Q6_Vw_vsub_VwVw(rsqrtconst, y);
+
+    // 1st iteration
+    ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y);
+    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
+    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
+    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
+    temp    = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp));
+
+    // 2nd iteration
+    y       = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
+    ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
+    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
+    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
+    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
+    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
+
+    // 3rd iteration
+    y       = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
+    ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
+    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
+    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
+    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
+    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
+
+    return Q6_Vsf_equals_Vqf32(temp);
+}
+
+static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
+    int step_of_1 = num_elems >> 5;
+    int remaining = num_elems - step_of_1 * VLEN_FP32;
+
+    assert(remaining == 0);
+
+    const HVX_Vector * restrict v_src = (HVX_Vector *) src;
+    HVX_Vector * restrict v_dst       = (HVX_Vector *) dst;
+
+    #pragma unroll(4)
+    for (int i = 0; i < step_of_1; i++) {
+        v_dst[i] = hvx_vec_fast_sigmoid_fp32(v_src[i]);
+    }
+}
+
+float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
+void  hvx_mul_f32(const uint8_t * restrict src0,
+                  const uint8_t * restrict src1,
+                  uint8_t * restrict dst,
+                  const int num_elems);
+void  hvx_mul_f32_opt(const uint8_t * restrict src0,
+                      const uint8_t * restrict src1,
+                      uint8_t * restrict dst,
+                      const int num_elems);
+void  hvx_mul_mul_f32_opt(const uint8_t * restrict src0,
+                          const uint8_t * restrict src1,
+                          const uint8_t * restrict src2,
+                          uint8_t * restrict dst,
+                          const int num_elems);
+void  hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
+void  hvx_add_f32(const uint8_t * restrict src0,
+                  const uint8_t * restrict src1,
+                  uint8_t * restrict dst,
+                  const int num_elems);
+void  hvx_add_f32_opt(const uint8_t * restrict src0,
+                      const uint8_t * restrict src1,
+                      uint8_t * restrict dst,
+                      const int num_elems);
+void  hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
+void  hvx_sub_f32(const uint8_t * restrict src0,
+                  const uint8_t * restrict src1,
+                  uint8_t * restrict dst,
+                  const int num_elems);
+void  hvx_sub_f32_opt(const uint8_t * restrict src0,
+                      const uint8_t * restrict src1,
+                      uint8_t * restrict dst,
+                      const int num_elems);
+void  hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
+void  hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale);
+void  hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
+void  hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
+void  hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate);
+float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems);
+float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems);
+void  hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
+void  hvx_clamp_scalar_f32(const uint8_t * restrict src,
+                           const float limit_left,
+                           const float limit_right,
+                           uint8_t * restrict dst,
+                           const int num_elems);
+
+#endif /* HVX_UTILS_H */
diff --git a/src/ggml-hexagon/htp/main.c b/src/ggml-hexagon/htp/main.c
new file mode 100644 (file)
index 0000000..e35ea3b
--- /dev/null
@@ -0,0 +1,945 @@
+#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
+#pragma clang diagnostic ignored "-Wunused-function"
+
+#define FARF_ERROR  1
+#define FARF_HIGH   1
+#define FARF_MEDIUM 0
+#define FARF_LOW    0
+#include <AEEStdErr.h>
+#include <dspqueue.h>
+#include <HAP_compute_res.h>
+#include <HAP_etm_config.h>
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <HAP_power.h>
+#include <HAP_ps.h>
+#include <qurt.h>
+#include <qurt_thread.h>
+#include <remote.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "ops-utils.h"
+#include "worker-pool.h"
+
+AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
+    struct htp_context * ctx;
+    int                  err = 0;
+
+    ctx = calloc(1, sizeof(*ctx));
+    if (ctx == NULL) {
+        return AEE_ENOMEMORY;
+    }
+
+    // Use the context structure as a handle
+    *handle = (remote_handle64) ctx;
+
+    // Enable FARF logs
+    HAP_setFARFRuntimeLoggingParams(0xffff, NULL, 0);
+
+    // Set client class
+    {
+        HAP_power_request_t request;
+        memset(&request, 0, sizeof(HAP_power_request_t));
+        request.type    = HAP_power_set_apptype;
+        request.apptype = HAP_POWER_COMPUTE_CLIENT_CLASS;
+
+        if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
+            return err;
+        }
+    }
+
+    {
+        HAP_power_request_t request;
+        memset(&request, 0, sizeof(request));
+
+        request.type                              = HAP_power_set_DCVS_v3;
+        request.dcvs_v3.set_dcvs_enable           = TRUE;
+        request.dcvs_v3.dcvs_enable               = TRUE;
+        request.dcvs_v3.dcvs_option               = HAP_DCVS_V2_PERFORMANCE_MODE;
+        request.dcvs_v3.set_bus_params            = TRUE;
+        request.dcvs_v3.bus_params.min_corner     = HAP_DCVS_VCORNER_MAX;
+        request.dcvs_v3.bus_params.max_corner     = HAP_DCVS_VCORNER_MAX;
+        request.dcvs_v3.bus_params.target_corner  = HAP_DCVS_VCORNER_MAX;
+        request.dcvs_v3.set_core_params           = TRUE;
+        request.dcvs_v3.core_params.min_corner    = HAP_DCVS_VCORNER_MAX;
+        request.dcvs_v3.core_params.max_corner    = HAP_DCVS_VCORNER_MAX;
+        request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX;
+        request.dcvs_v3.set_sleep_disable         = TRUE;
+        request.dcvs_v3.sleep_disable             = TRUE;
+        if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
+            return err;
+        }
+
+        memset(&request, 0, sizeof(request));
+        request.type         = HAP_power_set_HVX;
+        request.hvx.power_up = TRUE;
+        if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
+            return err;
+        }
+    }
+
+    {
+        // Power on HMX
+        HAP_power_request_t request;
+        memset(&request, 0, sizeof(HAP_power_request_t));
+        request.type         = HAP_power_set_HMX;
+        request.hmx.power_up = TRUE;
+        FARF(ALWAYS, "Powering HMX on\n");
+        err = HAP_power_set((void *) &ctx, &request);
+        if (err != AEE_SUCCESS) {
+            FARF(ERROR, "Error powering on HMX.");
+            return err;
+        }
+    }
+
+    return AEE_SUCCESS;
+}
+
+AEEResult htp_iface_close(remote_handle64 handle) {
+    struct htp_context * ctx = (struct htp_context *) handle;
+
+    if (!ctx) {
+        return AEE_EBADPARM;
+    }
+
+    if (ctx->queue) {
+        FARF(ERROR, "Closing handle with queue still open");
+        return AEE_EITEMBUSY;
+    }
+
+    free(ctx);
+    return AEE_SUCCESS;
+}
+
+AEEResult htp_iface_enable_etm(remote_handle64 handle) {
+    int err = HAP_user_etm_enable();
+    if (err) {
+        if (err == AEE_EVERSIONNOTSUPPORT) {
+            FARF(ERROR, "API HAP_user_etm_enable is not supported\n");
+        } else {
+            FARF(ERROR, "Error executing HAP_user_etm_enable with error code : 0x%x\n", err);
+        }
+    }
+    return err;
+}
+
+AEEResult htp_iface_disable_etm(remote_handle64 handle) {
+    int err = HAP_user_etm_disable();
+    if (err) {
+        if (err == AEE_EVERSIONNOTSUPPORT) {
+            FARF(ERROR, "API HAP_user_etm_disable is not supported\n");
+        } else {
+            FARF(ERROR, "Error executing HAP_user_etm_disable with error code : 0x%x\n", err);
+        }
+    }
+    return err;
+}
+
+static int vtcm_acquire(struct htp_context * ctx) {
+    if (!ctx->vtcm_valid) {
+        // Temporarily bump thread priority to make sure it's higher than other sessions.
+        // This way the resource manager will notify the other thread to release VTCM.
+        // Note that we need to reaquire VTCM at normal priority for this to work next time.
+        qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10);
+        HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
+        HAP_compute_res_release_cached(ctx->vtcm_rctx);
+        qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio);
+
+        HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
+        ctx->vtcm_valid = true;
+    }
+
+    ctx->vtcm_inuse = true;
+    return 0;
+}
+
+static int vtcm_release(struct htp_context * ctx) {
+    ctx->vtcm_inuse = false;
+
+    if (ctx->vtcm_valid && ctx->vtcm_needs_release) {
+        ctx->vtcm_valid         = false;
+        ctx->vtcm_needs_release = false;
+        HAP_compute_res_release_cached(ctx->vtcm_rctx);
+    }
+
+    return 0;
+}
+
+static int vtcm_release_callback(unsigned int rctx, void * state) {
+    struct htp_context * ctx = (struct htp_context *) state;
+
+    if (!ctx || ctx->vtcm_rctx != rctx) {
+        return AEE_EBADPARM;
+    }
+
+    // If VTCM is not inuse (not processing Ops) release it right here
+    // otherwise we'll release it once we're done with the current Op.
+
+    if (ctx->vtcm_inuse) {
+        ctx->vtcm_needs_release = false;
+        return 0;
+    }
+
+    ctx->vtcm_valid = false;
+    HAP_compute_res_release_cached(ctx->vtcm_rctx);
+
+    return 0;
+}
+
+static int vtcm_alloc(struct htp_context * ctx) {
+    unsigned int vtcm_size = 8 * 1024 * 1024;  // 8MB default
+    HAP_compute_res_query_VTCM(0, &vtcm_size, NULL, NULL, NULL);
+
+    compute_res_attr_t attr;
+    HAP_compute_res_attr_init(&attr);
+    HAP_compute_res_attr_set_serialize(&attr, 0);
+    HAP_compute_res_attr_set_cache_mode(&attr, 1);
+    HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size);
+    HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx);
+    HAP_compute_res_attr_set_hmx_param(&attr, 1);
+
+    // Allocate VTCM for scratch pads
+    uint32_t rctx = HAP_compute_res_acquire(&attr, 1000000 /* timeout */);
+    if (!rctx) {
+        FARF(ERROR, "failed to allocate %zu bytes VTCM\n", ctx->vtcm_size);
+        return AEE_ENOMEMORY;
+    }
+
+    void * vtcm_ptr;
+    if (HAP_compute_res_attr_get_vtcm_ptr_v2(&attr, &vtcm_ptr, &vtcm_size) != 0) {
+        HAP_compute_res_release(rctx);
+        FARF(ERROR, "failed to allocate %zu bytes VTCM (new)\n", ctx->vtcm_size);
+        return AEE_ENOMEMORY;
+    }
+
+    ctx->vtcm_base          = (uint8_t *) vtcm_ptr;
+    ctx->vtcm_size          = vtcm_size;
+    ctx->vtcm_rctx          = rctx;
+    ctx->vtcm_valid         = false;
+    ctx->vtcm_inuse         = false;
+    ctx->vtcm_needs_release = false;
+
+    return 0;
+}
+
+static void vtcm_free(struct htp_context * ctx) {
+    if (ctx->vtcm_rctx) {
+        HAP_compute_res_release(ctx->vtcm_rctx);
+        ctx->vtcm_base = 0;
+        ctx->vtcm_rctx = 0;
+    }
+}
+
+static void htp_packet_callback(dspqueue_t queue, int error, void * context);
+static void htp_error_callback(dspqueue_t queue, int error, void * context);
+
+AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) {
+    struct htp_context * ctx = (struct htp_context *) handle;
+
+    if (!ctx) {
+        return AEE_EBADPARM;
+    }
+
+    if (ctx->queue) {
+        FARF(ERROR, "Queue already open");
+        return AEE_EITEMBUSY;
+    }
+
+    // Import queue created on the CPU
+    int err = dspqueue_import(dsp_queue_id,         // Queue ID from dspqueue_export
+                              htp_packet_callback,  // Packet callback
+                              htp_error_callback,   // Error callback; no errors expected on the DSP
+                              (void *) ctx,         // Callback context
+                              &ctx->queue);
+
+    if (err) {
+        FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err);
+        return err;
+    }
+
+    ctx->thread_id   = qurt_thread_get_id();
+    ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id);
+
+    // allocate VTCM
+    err = vtcm_alloc(ctx);
+    if (err != AEE_SUCCESS) {
+        FARF(ERROR, "Unable to allocate VTCM");
+        return AEE_ENOMEMORY;
+    }
+
+    qurt_sysenv_max_hthreads_t hw_threads;
+    qurt_sysenv_get_max_hw_threads(&hw_threads);
+    uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
+
+    if (n_hvx == 0) {
+        n_hvx = hw_nhvx;
+    }
+    if (n_hvx > hw_threads.max_hthreads) {
+        n_hvx = hw_threads.max_hthreads;
+    }
+    if (n_hvx > HTP_MAX_NTHREADS) {
+        n_hvx = HTP_MAX_NTHREADS;
+    }
+
+    ctx->n_threads = n_hvx;
+    for (int i = 0; i < ctx->n_threads; i++) {
+        ctx->dma[i] = dma_queue_create(HTP_SPAD_SRC0_NROWS * 2);
+    }
+
+    // init worker pool
+    err = worker_pool_init(&ctx->worker_pool, n_hvx);
+    if (err != AEE_SUCCESS) {
+        FARF(ERROR, "Unable to create worker pool");
+        return err;
+    }
+
+    FARF(HIGH, "session %u started: n-hvx %u vtcm-size %zu vtcm-rctx %u n-threads %u thread-id %d thread-prio %d \n",
+         sess_id, hw_nhvx, ctx->vtcm_size, ctx->vtcm_rctx, ctx->n_threads, ctx->thread_id, ctx->thread_prio);
+
+    return AEE_SUCCESS;
+}
+
+AEEResult htp_iface_stop(remote_handle64 handle) {
+    struct htp_context * ctx = (struct htp_context *) handle;
+    if (!ctx) {
+        return AEE_EBADPARM;
+    }
+
+    if (!ctx->queue) {
+        FARF(ERROR, "Queue not open");
+        return AEE_EBADSTATE;
+    }
+
+    // Close queue. dspqueue_close() will also wait for callbacks to finish.
+    int err    = dspqueue_close(ctx->queue);
+    ctx->queue = NULL;
+    if (err != 0) {
+        FARF(ERROR, "Queue close failed with 0x%08x", (unsigned) err);
+        return err;
+    }
+
+    if (ctx->worker_pool) {
+        // Release worker pool
+        worker_pool_release(&ctx->worker_pool);
+    }
+
+    for (int i = 0; i < ctx->n_threads; i++) {
+        dma_queue_delete(ctx->dma[i]);
+    }
+
+    vtcm_free(ctx);
+
+    return AEE_SUCCESS;
+}
+
+static void htp_error_callback(dspqueue_t queue, int error, void * context) {
+    // No errors expected on the DSP.
+    FARF(ERROR, "Error callback: 0x%08x", (unsigned) error);
+}
+
+struct profile_data {
+    uint64_t usecs;
+    uint64_t cycles;
+    uint64_t pkts;
+};
+
+static inline void profile_start(struct profile_data * d) {
+    d->usecs  = HAP_perf_get_qtimer_count();
+    d->cycles = htp_get_cycles();
+    d->pkts   = htp_get_pktcnt();
+}
+
+static inline void profile_stop(struct profile_data * d) {
+    d->usecs  = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
+    d->cycles = htp_get_cycles() - d->cycles;
+    d->pkts   = htp_get_pktcnt() - d->pkts;
+}
+
+static int send_htp_rsp(struct htp_context *     c,
+                        uint32_t                 op,
+                        uint32_t                 status,
+                        struct dspqueue_buffer * bufs,
+                        size_t                   n_bufs,
+                        struct profile_data *    prof) {
+    // Prep response struct
+    struct htp_general_rsp rsp;
+    rsp.op          = op;
+    rsp.status      = status;
+    rsp.prof_usecs  = prof->usecs;
+    rsp.prof_cycles = prof->cycles;
+    rsp.prof_pkts   = prof->pkts;
+
+    int err = dspqueue_write(c->queue,
+                             0,                       // Flags
+                             n_bufs,
+                             bufs,                    // Buffer references
+                             sizeof(rsp),
+                             (const uint8_t *) &rsp,  // Message
+                             DSPQUEUE_TIMEOUT_NONE);
+
+    if (err != 0) {
+        FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err);
+    }
+
+    return err;
+}
+
+static void proc_matmul_req(struct htp_context *     ctx,
+                            struct htp_general_req * req,
+                            struct dspqueue_buffer * bufs,
+                            size_t                   n_bufs) {
+    // Prep response buffer structs (needed for error responses, etc)
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+    memset(rsp_bufs, 0, sizeof(rsp_bufs));
+    rsp_bufs[0].fd     = bufs[0].fd;
+    rsp_bufs[0].ptr    = bufs[0].ptr;
+    rsp_bufs[0].size   = bufs[0].size;
+    rsp_bufs[0].offset = bufs[0].offset;
+    rsp_bufs[0].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    rsp_bufs[1].fd     = bufs[1].fd;
+    rsp_bufs[1].ptr    = bufs[1].ptr;
+    rsp_bufs[1].size   = bufs[1].size;
+    rsp_bufs[1].offset = bufs[1].offset;
+    rsp_bufs[1].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[2].fd     = bufs[2].fd;
+    rsp_bufs[2].ptr    = bufs[2].ptr;
+    rsp_bufs[2].size   = bufs[2].size;
+    rsp_bufs[2].offset = bufs[2].offset;
+    rsp_bufs[2].flags  = (DSPQUEUE_BUFFER_FLAG_DEREF |                 // Release reference
+                         DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush NSP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.src1                   = req->src1;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.src1.data = (uint32_t) bufs[1].ptr;
+    octx.dst.data  = (uint32_t) bufs[2].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_matmul(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
+}
+
+static void proc_matmul_id_req(struct htp_context *     ctx,
+                               struct htp_general_req * req,
+                               struct dspqueue_buffer * bufs,
+                               size_t                   n_bufs) {
+    // Prep response buffer structs (needed for error responses, etc)
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+    memset(rsp_bufs, 0, sizeof(rsp_bufs));
+    rsp_bufs[0].fd     = bufs[0].fd;
+    rsp_bufs[0].ptr    = bufs[0].ptr;
+    rsp_bufs[0].size   = bufs[0].size;
+    rsp_bufs[0].offset = bufs[0].offset;
+    rsp_bufs[0].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    rsp_bufs[1].fd     = bufs[1].fd;
+    rsp_bufs[1].ptr    = bufs[1].ptr;
+    rsp_bufs[1].size   = bufs[1].size;
+    rsp_bufs[1].offset = bufs[1].offset;
+    rsp_bufs[1].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    rsp_bufs[2].fd     = bufs[2].fd;
+    rsp_bufs[2].ptr    = bufs[2].ptr;
+    rsp_bufs[2].size   = bufs[2].size;
+    rsp_bufs[2].offset = bufs[2].offset;
+    rsp_bufs[2].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[3].fd     = bufs[3].fd;
+    rsp_bufs[3].ptr    = bufs[3].ptr;
+    rsp_bufs[3].size   = bufs[3].size;
+    rsp_bufs[3].offset = bufs[3].offset;
+    rsp_bufs[3].flags  = (DSPQUEUE_BUFFER_FLAG_DEREF |                 // Release reference
+                         DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush NSP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.src1                   = req->src1;
+    octx.src2                   = req->src2;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.src1.data = (uint32_t) bufs[1].ptr;
+    octx.src2.data = (uint32_t) bufs[2].ptr;
+    octx.dst.data  = (uint32_t) bufs[3].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_matmul_id(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
+}
+
+static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+    memset(rsp_bufs, 0, sizeof(rsp_bufs));
+
+    rsp_bufs[0].fd     = bufs[0].fd;
+    rsp_bufs[0].ptr    = bufs[0].ptr;
+    rsp_bufs[0].offset = bufs[0].offset;
+    rsp_bufs[0].size   = bufs[0].size;
+    rsp_bufs[0].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    rsp_bufs[1].fd     = bufs[1].fd;
+    rsp_bufs[1].ptr    = bufs[1].ptr;
+    rsp_bufs[1].offset = bufs[1].offset;
+    rsp_bufs[1].size   = bufs[1].size;
+    rsp_bufs[1].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[2].fd     = bufs[2].fd;
+    rsp_bufs[2].ptr    = bufs[2].ptr;
+    rsp_bufs[2].offset = bufs[2].offset;
+    rsp_bufs[2].size   = bufs[2].size;
+    rsp_bufs[2].flags  = (DSPQUEUE_BUFFER_FLAG_DEREF |                 // Release reference
+                         DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush NSP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.src1                   = req->src1;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.src1.data = (uint32_t) bufs[1].ptr;
+    octx.dst.data  = (uint32_t) bufs[2].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_binary(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
+}
+
+static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+    memset(rsp_bufs, 0, sizeof(rsp_bufs));
+
+    rsp_bufs[0].fd     = bufs[0].fd;
+    rsp_bufs[0].ptr    = bufs[0].ptr;
+    rsp_bufs[0].offset = bufs[0].offset;
+    rsp_bufs[0].size   = bufs[0].size;
+    rsp_bufs[0].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    rsp_bufs[1].fd     = bufs[1].fd;
+    rsp_bufs[1].ptr    = bufs[1].ptr;
+    rsp_bufs[1].offset = bufs[1].offset;
+    rsp_bufs[1].size   = bufs[1].size;
+    rsp_bufs[1].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    rsp_bufs[2].fd     = bufs[2].fd;
+    rsp_bufs[2].ptr    = bufs[2].ptr;
+    rsp_bufs[2].offset = bufs[2].offset;
+    rsp_bufs[2].size   = bufs[2].size;
+    rsp_bufs[2].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[3].fd     = bufs[3].fd;
+    rsp_bufs[3].ptr    = bufs[3].ptr;
+    rsp_bufs[3].offset = bufs[3].offset;
+    rsp_bufs[3].size   = bufs[3].size;
+    rsp_bufs[3].flags  = (DSPQUEUE_BUFFER_FLAG_DEREF |                 // Release reference
+                         DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush NSP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.src1                   = req->src1;
+    octx.src2                   = req->src2;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.src1.data = (uint32_t) bufs[1].ptr;
+    octx.src2.data = (uint32_t) bufs[2].ptr;
+    octx.dst.data  = (uint32_t) bufs[3].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_binary(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
+}
+
+static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+    memset(rsp_bufs, 0, sizeof(rsp_bufs));
+
+    rsp_bufs[0].fd     = bufs[0].fd;
+    rsp_bufs[0].ptr    = bufs[0].ptr;
+    rsp_bufs[0].offset = bufs[0].offset;
+    rsp_bufs[0].size   = bufs[0].size;
+    rsp_bufs[0].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[1].fd     = bufs[1].fd;
+    rsp_bufs[1].ptr    = bufs[1].ptr;
+    rsp_bufs[1].offset = bufs[1].offset;
+    rsp_bufs[1].size   = bufs[1].size;
+    rsp_bufs[1].flags  = (DSPQUEUE_BUFFER_FLAG_DEREF |                 // Release reference
+                         DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush NSP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.dst.data  = (uint32_t) bufs[1].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_unary(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 2, &prof);
+}
+
+static void proc_activations_req(struct htp_context *     ctx,
+                                 struct htp_general_req * req,
+                                 struct dspqueue_buffer * bufs,
+                                 uint32_t                 n_bufs) {
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+    memset(rsp_bufs, 0, sizeof(rsp_bufs));
+
+    rsp_bufs[0].fd     = bufs[0].fd;
+    rsp_bufs[0].ptr    = bufs[0].ptr;
+    rsp_bufs[0].offset = bufs[0].offset;
+    rsp_bufs[0].size   = bufs[0].size;
+    rsp_bufs[0].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    int write_idx = 1;
+    if (3 == n_bufs) {
+        rsp_bufs[1].fd     = bufs[1].fd;
+        rsp_bufs[1].ptr    = bufs[1].ptr;
+        rsp_bufs[1].offset = bufs[1].offset;
+        rsp_bufs[1].size   = bufs[1].size;
+        rsp_bufs[1].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+        write_idx = 2;
+    }
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[write_idx].fd     = bufs[write_idx].fd;
+    rsp_bufs[write_idx].ptr    = bufs[write_idx].ptr;
+    rsp_bufs[write_idx].offset = bufs[write_idx].offset;
+    rsp_bufs[write_idx].size   = bufs[write_idx].size;
+    rsp_bufs[write_idx].flags  = (DSPQUEUE_BUFFER_FLAG_DEREF |                 // Release reference
+                                 DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush NSP
+                                 DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    if (3 == n_bufs) {
+        octx.src1 = req->src1;
+    }
+    octx.dst   = req->dst;
+    octx.flags = req->flags;
+    octx.op    = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    if (3 == n_bufs) {
+        octx.src1.data = (uint32_t) bufs[1].ptr;
+        octx.dst.data  = (uint32_t) bufs[2].ptr;
+    } else {
+        octx.dst.data = (uint32_t) bufs[1].ptr;
+    }
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        if (octx.op == HTP_OP_SOFTMAX) {
+            rsp_status = op_softmax(&octx);
+        } else {
+            rsp_status = op_activations(&octx);
+        }
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
+}
+
+static void proc_rope_req(struct htp_context *     ctx,
+                          struct htp_general_req * req,
+                          struct dspqueue_buffer * bufs,
+                          uint32_t                 n_bufs) {
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+    memset(rsp_bufs, 0, sizeof(rsp_bufs));
+
+    rsp_bufs[0].fd     = bufs[0].fd;
+    rsp_bufs[0].ptr    = bufs[0].ptr;
+    rsp_bufs[0].offset = bufs[0].offset;
+    rsp_bufs[0].size   = bufs[0].size;
+    rsp_bufs[0].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    rsp_bufs[1].fd     = bufs[1].fd;
+    rsp_bufs[1].ptr    = bufs[1].ptr;
+    rsp_bufs[1].offset = bufs[1].offset;
+    rsp_bufs[1].size   = bufs[1].size;
+    rsp_bufs[1].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+    int write_idx = 2;
+    if (4 == n_bufs) {
+        rsp_bufs[write_idx].fd     = bufs[write_idx].fd;
+        rsp_bufs[write_idx].ptr    = bufs[write_idx].ptr;
+        rsp_bufs[write_idx].offset = bufs[write_idx].offset;
+        rsp_bufs[write_idx].size   = bufs[write_idx].size;
+        rsp_bufs[write_idx].flags  = DSPQUEUE_BUFFER_FLAG_DEREF;  // Release reference
+
+        write_idx++;
+    }
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[write_idx].fd     = bufs[write_idx].fd;
+    rsp_bufs[write_idx].ptr    = bufs[write_idx].ptr;
+    rsp_bufs[write_idx].offset = bufs[write_idx].offset;
+    rsp_bufs[write_idx].size   = bufs[write_idx].size;
+    rsp_bufs[write_idx].flags  = (DSPQUEUE_BUFFER_FLAG_DEREF |                 // Release reference
+                                 DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |          // Flush NSP
+                                 DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.src1                   = req->src1;
+    if (4 == n_bufs) {
+        octx.src2 = req->src2;
+    }
+    octx.dst   = req->dst;
+    octx.flags = req->flags;
+    octx.op    = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.src1.data = (uint32_t) bufs[1].ptr;
+    if (4 == n_bufs) {
+        octx.src2.data = (uint32_t) bufs[2].ptr;
+        octx.dst.data  = (uint32_t) bufs[3].ptr;
+    } else {
+        octx.dst.data = (uint32_t) bufs[2].ptr;
+    }
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_rope(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
+}
+
+static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
+    struct htp_context * ctx = (struct htp_context *) context;
+
+    // Repeatedly read packets from the queue until it's empty. We don't
+    // necessarily get a separate callback for each packet, and new packets
+    // may arrive while we're processing the previous one. This ensures we
+    // keep the DSP busy as much as possible and avoid waiting for the CPU.
+
+    while (1) {
+        struct htp_general_req req;
+        uint32_t               req_size;
+
+        struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
+        uint32_t               n_bufs;
+        uint32_t               flags;
+
+        // Read packet from queue
+        int err = dspqueue_read_noblock(queue, &flags,
+                                        HTP_MAX_PACKET_BUFFERS,  // Maximum number of buffer references
+                                        &n_bufs,                 // Number of buffer references
+                                        bufs,                    // Buffer references
+                                        sizeof(req),             // Max message length
+                                        &req_size,               // Message length
+                                        (uint8_t *) &req);       // Message
+
+        if (err == AEE_EWOULDBLOCK) {
+            // Consumed all packets available for now
+            return;
+        }
+
+        if (err != 0) {
+            FARF(ERROR, "dspqueue_read_noblock failed: 0x%08x", (unsigned) err);
+            return;
+        }
+
+        if (req_size != sizeof(req)) {
+            FARF(ERROR, "Invalid request size");
+            continue;
+        }
+
+        if (req.flags & HTP_OPFLAGS_EARLY_WAKEUP) {
+            // Host wants early notification
+            dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0);
+        }
+
+        // Process packet based on its message type
+        switch (req.op) {
+            case HTP_OP_MUL_MAT:
+                if (n_bufs != 3) {
+                    FARF(ERROR, "Bad matmul-req buffer list");
+                    continue;
+                }
+                proc_matmul_req(ctx, &req, bufs, n_bufs);
+                break;
+
+            case HTP_OP_MUL_MAT_ID:
+                if (n_bufs != 4) {
+                    FARF(ERROR, "Bad matmul-id-req buffer list");
+                    continue;
+                }
+                proc_matmul_id_req(ctx, &req, bufs, n_bufs);
+                break;
+
+            case HTP_OP_MUL:
+            case HTP_OP_ADD:
+            case HTP_OP_SUB:
+                if (n_bufs != 3) {
+                    FARF(ERROR, "Bad binary-req buffer list");
+                    continue;
+                }
+                proc_binary_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_RMS_NORM:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad unary-req buffer list");
+                    continue;
+                }
+
+                proc_unary_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_UNARY_SILU:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad act-req buffer list");
+                    continue;
+                }
+                proc_activations_req(ctx, &req, bufs, n_bufs);
+                break;
+
+            case HTP_OP_GLU_SWIGLU:
+            case HTP_OP_SOFTMAX:
+                if ((n_bufs != 2) && (n_bufs != 3)) {
+                    FARF(ERROR, "Bad act-req buffer list");
+                    continue;
+                }
+                proc_activations_req(ctx, &req, bufs, n_bufs);
+                break;
+
+            case HTP_OP_ADD_ID:
+                if (n_bufs != 4) {
+                    FARF(ERROR, "Bad add-id-req buffer list");
+                    continue;
+                }
+                proc_add_id_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_ROPE:
+                if ((n_bufs != 3) && (n_bufs != 4)) {
+                    FARF(ERROR, "Bad rope-req buffer list");
+                    continue;
+                }
+                proc_rope_req(ctx, &req, bufs, n_bufs);
+                break;
+
+            default:
+                FARF(ERROR, "Unknown Op %u", req.op);
+                break;
+        }
+    }
+}
diff --git a/src/ggml-hexagon/htp/matmul-ops.c b/src/ggml-hexagon/htp/matmul-ops.c
new file mode 100644 (file)
index 0000000..c99b6a0
--- /dev/null
@@ -0,0 +1,2223 @@
+#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+#    define FARF_HIGH 1
+#endif
+
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <HAP_ps.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <qurt_thread.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+struct htp_matmul_type {
+    const char * type;
+    void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+    void (*vec_dot_rx2)(const int n,
+                        float * restrict s,
+                        const void * restrict vx,
+                        uint32_t vx_row_size,
+                        const void * restrict vy);
+};
+
+typedef struct {
+    HVX_Vector v[2];
+} HVX_Vector_x2;
+
+typedef struct {
+    HVX_Vector v[4];
+} HVX_Vector_x4;
+
+typedef struct {
+    HVX_Vector v[8];
+} HVX_Vector_x8;
+
+// vdelta control to replicate first 4x fp32 values across lanes
+static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = {
+    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
+    0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
+    0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
+    0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
+    0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
+    0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
+    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
+};
+
+// vdelta control to replicate and interleave first 8x fp32 values across lanes
+static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_fp32[128] = {
+    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
+    0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
+    0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
+    0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
+    0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
+    0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
+    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
+};
+
+// vdelta control to replicate first fp32 value across all elements
+static const uint8_t __attribute__((aligned(128))) repl_1x_fp32[128] = {
+    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
+    0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
+    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
+    0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
+    0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
+    0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
+    0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+};
+
+// vdelta control to replicate first fp16 value across all elements
+static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = {
+    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
+    0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
+    0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
+    0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
+    0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
+    0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
+    0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+};
+
+// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
+static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
+    0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
+    0x00, 0x11, 0x10, 0x10, 0x10, 0x02, 0x00, 0x04, 0x00, 0x01, 0x02, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04,
+    0x00, 0x00, 0x22, 0x20, 0x20, 0x20, 0x21, 0x22, 0x20, 0x24, 0x04, 0x00, 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x02,
+    0x00, 0x04, 0x00, 0x11, 0x12, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08,
+    0x01, 0x02, 0x00, 0x04, 0x44, 0x40, 0x40, 0x40, 0x41, 0x40, 0x40, 0x40, 0x42, 0x40, 0x44, 0x40, 0x41, 0x42, 0x48,
+    0x48, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x12, 0x10, 0x10, 0x10, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00,
+    0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
+};
+
+static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
+    0,    0, 1,    0, 2,    0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
+    0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0, 0,    0,
+    0,    0, 0,    0, 0,    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0, 0,    0,
+    0,    0, 0,    0, 0,    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0, 0,    0,
+    0,    0, 0,    0, 0,    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0,    0, 0,    0, 0,    0,
+};
+
+// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
+
+static inline size_t q8x4x2_row_size(uint32_t ne) {
+    // ensures perfect alignment of quants and full row
+    const uint32_t qk = QK_Q8_0x4x2;
+    const uint32_t nb = (ne + qk - 1) / qk;
+    return htp_round_up(ne + nb * 8 * sizeof(__fp16), 128);
+}
+
+static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
+    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
+
+    HVX_Vector v0_1 = vptr[0];  // first 256 elements (128 bytes)
+    HVX_Vector v2_3 = vptr[1];  // ...
+    HVX_Vector v4_5 = vptr[2];  // ...
+    HVX_Vector v6_7 = vptr[3];  // ...
+
+    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+
+    HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F
+    HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4
+    HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4);  // & 0x0F
+    HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4);    // >> 4
+    HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4);  // & 0x0F
+    HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4);    // >> 4
+    HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4);  // & 0x0F
+    HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4
+
+    // Convert uint4 to int4 (i.e. x - 8)
+    const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
+    v0                  = Q6_Vb_vsub_VbVb(v0, i8);
+    v1                  = Q6_Vb_vsub_VbVb(v1, i8);
+    v2                  = Q6_Vb_vsub_VbVb(v2, i8);
+    v3                  = Q6_Vb_vsub_VbVb(v3, i8);
+    v4                  = Q6_Vb_vsub_VbVb(v4, i8);
+    v5                  = Q6_Vb_vsub_VbVb(v5, i8);
+    v6                  = Q6_Vb_vsub_VbVb(v6, i8);
+    v7                  = Q6_Vb_vsub_VbVb(v7, i8);
+
+    HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
+    return r;
+}
+
+static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) {
+    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
+
+    HVX_Vector v0_1 = vptr[0];  // first 256 elements (128 bytes)
+    HVX_Vector v2_3 = vptr[1];  // ...
+    HVX_Vector v4_5 = vptr[2];  // ...
+    HVX_Vector v6_7 = vptr[3];  // ...
+
+    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+
+    HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F
+    HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4
+    HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4);  // & 0x0F
+    HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4);    // >> 4
+    HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4);  // & 0x0F
+    HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4);    // >> 4
+    HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4);  // & 0x0F
+    HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4
+
+    HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
+    v0             = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
+    v1             = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
+    v2             = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
+    v3             = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
+    v4             = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
+    v5             = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
+    v6             = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
+    v7             = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
+
+    HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
+    return r;
+}
+
+static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
+    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
+
+    HVX_Vector v0 = vptr[0];  // first  128 vals
+    HVX_Vector v1 = vptr[1];  // ...
+    HVX_Vector v2 = vptr[2];  // ...
+    HVX_Vector v3 = vptr[3];  // ...
+    HVX_Vector v4 = vptr[4];  // ...
+    HVX_Vector v5 = vptr[5];  // ...
+    HVX_Vector v6 = vptr[6];  // ...
+    HVX_Vector v7 = vptr[7];  // ...
+
+    HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
+    return r;
+}
+
+static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) {
+    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
+
+    HVX_Vector v0 = vptr[0];  // first  64 vals
+    HVX_Vector v1 = vptr[1];  // second 64 vals
+    HVX_Vector v2 = vptr[2];  // third  64 vals
+    HVX_Vector v3 = vptr[3];  // forth  64 vals
+
+    HVX_Vector_x4 r = { v0, v1, v2, v3 };
+    return r;
+}
+
+static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) {
+    const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr;
+
+    HVX_VectorPair v0 = vptr[0];  // first  64 vals
+    HVX_VectorPair v1 = vptr[1];  // second 64 vals
+    HVX_VectorPair v2 = vptr[2];  // third  64 vals
+    HVX_VectorPair v3 = vptr[3];  // forth  64 vals
+
+    HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero());
+    HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero());
+    HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero());
+    HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero());
+    HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero());
+    HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero());
+    HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero());
+    HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero());
+
+    HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo));
+    HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo));
+    HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo));
+    HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo));
+
+    // vcombine does a shuffle, use vdeal to undo
+
+    HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) };
+    return r;
+}
+
+// Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors).
+// Accumulate each block into a single int32 value.
+// Return a single HVX vector with 32x int32 accumulators.
+// This version is parameterized to support less than 1024 elements.
+// if() checks are optimized out at compile time -- make sure to pass N as a constexpr.
+
+static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
+    HVX_Vector r0 = Q6_V_vsplat_R(0);
+    HVX_Vector r1 = Q6_V_vsplat_R(0);
+    HVX_Vector r2 = Q6_V_vsplat_R(0);
+    HVX_Vector r3 = Q6_V_vsplat_R(0);
+    HVX_Vector r4 = Q6_V_vsplat_R(0);
+    HVX_Vector r5 = Q6_V_vsplat_R(0);
+    HVX_Vector r6 = Q6_V_vsplat_R(0);
+    HVX_Vector r7 = Q6_V_vsplat_R(0);
+
+    HVX_VectorPair p3;
+    HVX_VectorPair p2;
+    HVX_VectorPair p1;
+    HVX_VectorPair p0;
+
+    if (n >=  128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); }
+    if (n >=  256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); }
+    if (n >=  384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); }
+    if (n >=  512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); }
+    if (n >=  640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); }
+    if (n >=  768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); }
+    if (n >=  896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); }
+    if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); }
+
+    if (n >=  128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+    if (n >=  384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
+    if (n >=  640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); }
+    if (n >=  896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); }
+
+    if (n >=  128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
+    if (n >=  384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
+    if (n >=  640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); }
+    if (n >=  896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); }
+
+    if (n >=  128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+    if (n >=  640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
+
+    if (n >=  128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
+    if (n >=  640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
+
+    if (n >=  128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+    if (n >=  128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
+
+    return r0;
+}
+
+static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
+    return hvx_vec_rmpy_x8_n(x, y, 1024);
+}
+
+// Handle most common cases of tensors not multiple of 1024.
+static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
+    if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); };
+    if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); };
+    if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); };
+    return hvx_vec_rmpy_x8_n(x, y, 1024);
+}
+
+static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % 32 == 0);  // min sub-block size
+    assert((unsigned long) vx % 128 == 0);
+    assert((unsigned long) vy % 128 == 0);
+
+    const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
+    const uint32_t x_qblk_size = qk / 2;                                     // int4
+    const uint32_t x_qrow_size = n / 2;                                      // int4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                         // int8
+    const uint32_t y_qrow_size = n;                                          // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0);            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size);  // then scales
+
+    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
+
+    // Row sum (qf32)
+    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+
+    // Multiply and accumulate into int32.
+    // Compute combined scale (fp32).
+    // Apply scale to acc and accumulate into the row sum (qf32).
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    const uint32_t nloe = n % qk;  // num leftover elemements
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+    }
+
+    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+    if (nloe) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+    }
+
+    // Reduce and convert into fp32
+    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+
+    hvx_vec_store_u(&s[0], 4, r0_sum);
+}
+
+static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
+                                      float * restrict s,
+                                      const void * restrict vx,
+                                      uint32_t vx_row_size,
+                                      const void * restrict vy) {
+    assert(n % 32 == 0);  // min sub-block size
+    assert((unsigned long) vx % 128 == 0);
+    assert((unsigned long) vy % 128 == 0);
+
+    const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
+    const uint32_t x_qblk_size = qk / 2;                                                           // int4
+    const uint32_t x_qrow_size = n / 2;                                                            // int4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                                               // int8
+    const uint32_t y_qrow_size = n;                                                                // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0);            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size);  // then scales
+
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0);            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size);  // then scales
+
+    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);                                     // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);                           // then scales
+
+    // Row sum (qf32)
+    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+
+    // Multiply and accumulate into int32.
+    // Compute combined scale (fp32).
+    // Apply scale to acc and accumulate into the row sum (qf32).
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    const uint32_t nloe = n % qk;  // num leftover elemements
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+    }
+
+    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+    if (nloe) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+        r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+    }
+
+    // Convert into fp32 and reduce
+    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
+    HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
+
+    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+}
+
+static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % 32 == 0);  // min sub-block size
+    assert((unsigned long) vx % 128 == 0);
+    assert((unsigned long) vy % 128 == 0);
+
+    const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
+    const uint32_t x_qblk_size = qk;                                         // int8
+    const uint32_t x_qrow_size = n;                                          // int8 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                         // int8
+    const uint32_t y_qrow_size = n;                                          // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0);            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size);  // then scales
+
+    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
+
+    // Row sum (qf32)
+    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+
+    // Multiply and accumulate into int32.
+    // Compute combined scale (fp32).
+    // Apply scale to acc and accumulate into the row sum (qf32).
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+    }
+
+    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+    if (nloe) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+    }
+
+    // Reduce and convert into fp32
+    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+
+    hvx_vec_store_u(&s[0], 4, r0_sum);
+}
+
+static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
+                                      float * restrict s,
+                                      const void * restrict vx,
+                                      uint32_t vx_row_size,
+                                      const void * restrict vy) {
+    assert(n % 32 == 0);  // min sub-block size
+    assert((unsigned long) vx % 128 == 0);
+    assert((unsigned long) vy % 128 == 0);
+
+    const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
+    const uint32_t x_qblk_size = qk;                                                               // int8
+    const uint32_t x_qrow_size = n;                                                                // int8 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                                               // int8
+    const uint32_t y_qrow_size = n;                                                                // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0);            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size);  // then scales
+
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0);            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size);  // then scales
+
+    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);                                     // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);                           // then scales
+
+    // Row sum (qf32)
+    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+
+    // Multiply and accumulate into int32.
+    // Compute combined scale (fp32).
+    // Apply scale to acc and accumulate into the row sum (qf32).
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+    }
+
+    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+    if (nloe) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+        r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+    }
+
+    // Convert into fp32 and reduce
+    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
+    HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
+
+    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+}
+
+static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
+                                     float * restrict s,
+                                     const void * restrict vx,
+                                     const void * restrict vy) {
+    assert(n % 32 == 0);  // min sub-block size
+    assert((unsigned long) vx % 128 == 0);
+    assert((unsigned long) vy % 128 == 0);
+
+    const uint32_t qk = QK_MXFP4x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 1;                                  // 32x e8m0
+    const uint32_t x_qblk_size = qk / 2;                                     // fp4
+    const uint32_t x_qrow_size = n / 2;                                      // fp4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                         // int8
+    const uint32_t y_qrow_size = n;                                          // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0);            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size);  // then scales
+
+    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
+
+    // Row sum (qf32)
+    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+
+    // Multiply and accumulate into int32.
+    // Compute combined scale (fp32).
+    // Apply scale to acc and accumulate into the row sum (qf32).
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+        vy_d            = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+        vy_d            = Q6_Vsf_equals_Vqf32(vy_d);
+
+        // Convert rX_d scales from e8m0 to fp32
+        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+        // Left shift with zero fill to create FP32
+        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
+        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+    }
+
+    // Process leftovers
+    if (nloe) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+        vy_d            = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+        vy_d            = Q6_Vsf_equals_Vqf32(vy_d);
+
+        // Convert rX_d scales from e8m0 to fp32
+        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+        // Left shift with zero fill to create FP32
+        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
+        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
+
+        // Zero-out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+    }
+
+    // Reduce and convert into fp32
+    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+
+    hvx_vec_store_u(&s[0], 4, r0_sum);
+}
+
+static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
+                                         float * restrict s,
+                                         const void * restrict vx,
+                                         uint32_t vx_row_size,
+                                         const void * restrict vy) {
+    assert(n % 32 == 0);  // min sub-block size
+    assert((unsigned long) vx % 128 == 0);
+    assert((unsigned long) vy % 128 == 0);
+
+    const uint32_t qk = QK_MXFP4x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 1;                                                        // 32x e8m0
+    const uint32_t x_qblk_size = qk / 2;                                                           // fp4
+    const uint32_t x_qrow_size = n / 2;                                                            // fp4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                                               // int8
+    const uint32_t y_qrow_size = n;                                                                // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0);            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size);  // then scales
+
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0);            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size);  // then scales
+
+    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);                                     // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);                           // then scales
+
+    // Row sum (qf32)
+    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+
+    // Multiply and accumulate into int32.
+    // Compute combined scale (fp32).
+    // Apply scale to acc and accumulate into the row sum (qf32).
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    int32_t        nloe = n % qk;  // num leftover elemements (must be signed)
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+
+        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+        HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+        vy_d            = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+        vy_d            = Q6_Vsf_equals_Vqf32(vy_d);
+
+        // Convert rX_d scales from e8m0 to fp32
+        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+        // Left shift with zero fill to create FP32
+        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
+        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
+        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);
+        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);
+        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
+        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+    }
+
+    // Process leftovers
+    if (nloe) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+
+        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+        HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+        vy_d            = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+        vy_d            = Q6_Vsf_equals_Vqf32(vy_d);
+
+        // Convert rX_d scales from e8m0 to fp32
+        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+        // Left shift with zero fill to create FP32
+        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
+        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
+        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);
+        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);
+        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
+        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
+
+        // Zero-out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+        r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+    }
+
+    // Convert into fp32 and reduce
+    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
+    HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
+
+    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+}
+
+#if 1
+static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    if (0) {
+        float rsum                 = 0;
+        const __fp16 * restrict vx = (const __fp16 * restrict) x;
+        const float * restrict vy  = (const float * restrict) y;
+
+        for (uint32_t i = 0; i < n; i++) {
+            rsum += vx[i] * (__fp16) vy[i];
+        }
+        *s = rsum;
+        return;
+    }
+
+    const HVX_UVector * restrict vx     = (const HVX_UVector * restrict) x;
+    const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y;
+
+    uint32_t nv0 = n / 64;  // num full fp16 hvx vectors
+    uint32_t nv1 = n % 64;  // leftover elements
+
+    // for some reason we need volatile here so that the compiler doesn't try anything funky
+    volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+    uint32_t i = 0;
+
+    for (i = 0; i < nv0; i++) {
+        HVX_VectorPair yp = vy[i];
+
+        HVX_Vector     x  = vx[i];
+        HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00));  // mul by 1.0
+
+        HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
+        HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
+
+        HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo);
+        rsum           = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
+    }
+
+    if (nv1) {
+        HVX_VectorPair yp = vy[i];
+
+        HVX_Vector     x  = vx[i];
+        HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00));  // mul by 1.0
+
+        if (nv1 >= 32) {
+            HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
+            rsum          = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
+            nv1 -= 32;
+        }
+
+        rsum = hvx_vec_qf32_reduce_sum(rsum);
+
+        if (nv1) {
+            HVX_Vector lo  = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
+            HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
+            rsum           = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
+        }
+
+        // hvx_vec_dump_fp16("X", x);
+        // hvx_vec_dump_fp16("Y", y);
+        // hvx_vec_dump_fp32("SUM",  Q6_Vsf_equals_Vqf32(sum));
+        // hvx_vec_dump_fp32("RSUM", Q6_Vsf_equals_Vqf32(rsum));
+    } else {
+        rsum = hvx_vec_qf32_reduce_sum(rsum);
+    }
+
+    *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum));
+
+#    ifdef HTP_DEBUG
+    {
+        float rsum                 = 0;
+        const __fp16 * restrict vx = (const __fp16 * restrict) x;
+        const float * restrict vy  = (const float * restrict) y;
+
+        for (uint32_t i = 0; i < n; i++) {
+            rsum += vx[i] * vy[i];
+        }
+
+        float diff = fabs(*s - rsum);
+        if (diff > 0.001) {
+            FARF(HIGH, "vec-dot-f16-missmatch: %u (%u:%u) expected %.6f got %.6f\n", n, nv0, nv1, rsum, *s);
+            // htp_dump_f16("x", vx, n);
+            // htp_dump_f32("y", vy, n);
+        }
+    }
+#    endif
+}
+#else
+static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    const uint32_t fk = 64;
+    const uint32_t nb = n / fk;
+
+    assert(n % fk == 0);
+    assert(nb % 4 == 0);
+
+    const uint32_t x_blk_size = 2 * fk;  // fp16
+    const uint32_t y_blk_size = 4 * fk;  // fp32
+
+    // Row sum (qf32)
+    HVX_Vector rsum0 = Q6_V_vsplat_R(0);
+    HVX_Vector rsum1 = Q6_V_vsplat_R(0);
+    HVX_Vector rsum2 = Q6_V_vsplat_R(0);
+    HVX_Vector rsum3 = Q6_V_vsplat_R(0);
+
+    for (uint32_t i = 0; i < nb; i += 4) {
+        HVX_Vector_x4 vx = hvx_vec_load_x4_f16(x + (i * x_blk_size));
+        HVX_Vector_x4 vy = hvx_vec_load_x4_f32_as_f16(y + (i * y_blk_size));
+
+        HVX_VectorPair fa0 = Q6_Wqf32_vmpy_VhfVhf(vx.v[0], vy.v[0]);
+        HVX_VectorPair fa1 = Q6_Wqf32_vmpy_VhfVhf(vx.v[1], vy.v[1]);
+        HVX_VectorPair fa2 = Q6_Wqf32_vmpy_VhfVhf(vx.v[2], vy.v[2]);
+        HVX_VectorPair fa3 = Q6_Wqf32_vmpy_VhfVhf(vx.v[3], vy.v[3]);
+
+        rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa0), Q6_V_hi_W(fa0)));
+        rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa1), Q6_V_hi_W(fa1)));
+        rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa2), Q6_V_hi_W(fa2)));
+        rsum3 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum3, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa3), Q6_V_hi_W(fa3)));
+    }
+
+    // Reduce and convert into fp32
+    rsum0           = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum1);
+    rsum2           = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, rsum3);
+    HVX_Vector rsum = hvx_vec_qf32_reduce_sum(Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum2));
+    hvx_vec_store_u(s, 4, Q6_Vsf_equals_Vqf32(rsum));
+}
+#endif
+
+#define htp_matmul_preamble            \
+    const uint32_t ne00 = src0->ne[0]; \
+    const uint32_t ne01 = src0->ne[1]; \
+    const uint32_t ne02 = src0->ne[2]; \
+    const uint32_t ne03 = src0->ne[3]; \
+                                       \
+    const uint32_t ne10 = src1->ne[0]; \
+    const uint32_t ne11 = src1->ne[1]; \
+    const uint32_t ne12 = src1->ne[2]; \
+    const uint32_t ne13 = src1->ne[3]; \
+                                       \
+    const uint32_t ne0 = dst->ne[0];   \
+    const uint32_t ne1 = dst->ne[1];   \
+    const uint32_t ne2 = dst->ne[2];   \
+    const uint32_t ne3 = dst->ne[3];   \
+                                       \
+    const uint32_t nb00 = src0->nb[0]; \
+    const uint32_t nb01 = src0->nb[1]; \
+    const uint32_t nb02 = src0->nb[2]; \
+    const uint32_t nb03 = src0->nb[3]; \
+                                       \
+    const uint32_t nb10 = src1->nb[0]; \
+    const uint32_t nb11 = src1->nb[1]; \
+    const uint32_t nb12 = src1->nb[2]; \
+    const uint32_t nb13 = src1->nb[3]; \
+                                       \
+    const uint32_t nb0 = dst->nb[0];   \
+    const uint32_t nb1 = dst->nb[1];   \
+    const uint32_t nb2 = dst->nb[2];   \
+    const uint32_t nb3 = dst->nb[3];
+
+// q8x4 src1 tensor is already in VTCM spad
+static void matmul(struct htp_matmul_type * mt,
+                   struct htp_tensor * restrict src0,
+                   struct htp_tensor * restrict src1,
+                   struct htp_tensor * restrict dst,
+                   struct htp_spad * restrict src0_spad,
+                   struct htp_spad * restrict src1_spad,
+                   struct htp_spad * restrict dst_spad,
+                   uint32_t    nth,
+                   uint32_t    ith,
+                   uint32_t    src0_nrows_per_thread,
+                   dma_queue * dma_queue) {
+    htp_matmul_preamble;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    const uint32_t src1_nrows = ne11 * ne12 * ne13;  // src1 rows
+
+    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    const size_t dst_row_size  = nb1;
+    const size_t src0_row_size = nb01;
+    const size_t src1_row_size = q8x4x2_row_size(ne10);
+
+    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+
+    // Per-thread VTCM scratchpads for all tensors
+    // Note that the entire src1 tensor is already in VTCM
+    // For other tensors we allocate N rows per thread, padded to HVX vector size
+    uint8_t * restrict spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;
+    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
+    uint8_t * restrict src1_data = src1_spad->data;
+
+    volatile uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
+
+    // Prefill spad with src0 rows
+    #pragma unroll(4)
+    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+        const int is0 = (ir0 - src0_start_row);
+        if (is0 >= HTP_SPAD_SRC0_NROWS) {
+            break;
+        }
+        dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
+                       src0_row_size_padded, src0_row_size, 2);
+    }
+
+    // Process src0 rows
+    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+        const uint8_t * ss0 = dma_queue_pop(dma_queue);
+
+        #pragma unroll(2)
+        for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
+            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
+            float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));
+            mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+        }
+
+        // Prefetch next (n + spad_nrows) row
+        const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
+        const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+        if (pr0 < src0_end_row_x2) {
+            dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size,
+                           src0_row_size_padded, src0_row_size, 2);
+        }
+    }
+
+    // Process the last row (if any)
+    if (src0_end_row != src0_end_row_x2) {
+        uint32_t  ir0 = src0_end_row_x2;
+        const int is0 = (ir0 - src0_start_row);
+        dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
+                       src0_row_size_padded, src0_row_size, 1);
+        const uint8_t * ss0 = dma_queue_pop(dma_queue);
+
+        #pragma unroll(2)
+        for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
+            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
+            float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));
+            mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
+         src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// q8x4x2 src1 tensor is already in VTCM spad
+static void matvec(struct htp_matmul_type * mt,
+                   struct htp_tensor * restrict src0,
+                   struct htp_tensor * restrict src1,
+                   struct htp_tensor * restrict dst,
+                   struct htp_spad * restrict src0_spad,
+                   struct htp_spad * restrict src1_spad,
+                   struct htp_spad * restrict dst_spad,
+                   uint32_t    nth,
+                   uint32_t    ith,
+                   uint32_t    src0_nrows_per_thread,
+                   dma_queue * dma_queue) {
+    htp_matmul_preamble;
+
+    const uint32_t src0_nrows = ne01;
+
+    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    const size_t dst_row_size  = nb1;
+    const size_t src0_row_size = nb01;
+    const size_t src1_row_size = q8x4x2_row_size(ne10);
+
+    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+
+    // Per-thread VTCM scratchpads for all tensors
+    // Note that the entire src1 tensor is already in VTCM
+    // For other tensors we allocate N rows per thread, padded to HVX vector size
+    uint8_t * spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;
+    uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
+    uint8_t * src1_data = src1_spad->data;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    float * tmp = (float *) spad_dst;
+
+    const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
+    const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
+    float * restrict dst_col          = (float *) dst->data;
+
+    // Prefill spad with 2x src0 rows
+    #pragma unroll(2)
+    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+        const uint32_t is0 = (ir0 - src0_start_row);
+        if (is0 >= HTP_SPAD_SRC0_NROWS) {
+            break;
+        }
+        dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
+                       src0_row_size_padded, src0_row_size, 2);
+    }
+
+    // Process src0 rows
+    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+        const uint8_t * ss0 = dma_queue_pop(dma_queue);
+        mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_row_size_padded, src1_col);
+
+        // Prefetch next (n + spad_nrows) row
+        const uint32_t pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
+        const uint32_t is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+        if (pr0 < src0_end_row_x2) {
+            dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size,
+                           src0_row_size_padded, src0_row_size, 2);
+        }
+    }
+
+    // Process the last row (if any)
+    if (src0_end_row != src0_end_row_x2) {
+        const uint32_t ir0 = src0_end_row_x2;
+        const uint32_t is0 = (ir0 - src0_start_row);
+        dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
+                       src0_row_size_padded, src0_row_size, 1);
+        const uint8_t * ss0 = dma_queue_pop(dma_queue);
+        mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
+    }
+
+    hvx_copy_fp32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
+         src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)]
+
+struct mmid_row_mapping {
+    uint32_t i1;
+    uint32_t i2;
+};
+
+// q8x4 src1 tensor is already in VTCM spad
+static void matmul_id(struct htp_matmul_type * mt,
+                      struct htp_tensor * restrict src0,
+                      struct htp_tensor * restrict src1,
+                      struct htp_tensor * restrict ids,
+                      struct htp_tensor * restrict dst,
+                      struct htp_spad * restrict src0_spad,
+                      struct htp_spad * restrict src1_spad,
+                      struct htp_spad * restrict src2_spad,
+                      struct htp_spad * restrict dst_spad,
+                      uint32_t    nth,
+                      uint32_t    ith,
+                      uint32_t    src0_nrows_per_thread,
+                      dma_queue * dma_queue) {
+    htp_matmul_preamble;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t src0_nrows = ne01;  // src0 rows per expert
+    const uint32_t src1_nrows = ne11;
+
+    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    const uint32_t n_ids = ids->ne[0];  // n_expert_used
+    const uint32_t n_as  = ne02;        // n_expert
+
+    const size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
+    const size_t matrix_row_map_size    = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
+
+    const uint32_t *                matrix_row_counts = (const uint32_t *) src2_spad->data + 0;
+    const struct mmid_row_mapping * matrix_rows       = (const void *) src2_spad->data + matrix_row_counts_size;
+
+    const size_t dst_row_size  = nb1;
+    const size_t src0_row_size = nb01;
+    const size_t src1_row_size = q8x4x2_row_size(ne10);
+
+    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+
+    // Per-thread VTCM scratchpads for all tensors
+    // Note that the entire src1 tensor is already in VTCM
+    // For other tensors we allocate N rows per thread, padded to HVX vector size
+    uint8_t * restrict spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;
+    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
+    uint8_t * restrict src1_data = src1_spad->data;
+
+    for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) {
+        const int32_t cne1 = matrix_row_counts[cur_a];
+
+        if (cne1 == 0) {
+            continue;
+        }
+
+        const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0);
+
+        // Prefill spad with src0 rows
+        #pragma unroll(4)
+        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+            const int is0 = (ir0 - src0_start_row);
+            if (is0 >= HTP_SPAD_SRC0_NROWS) {
+                break;
+            }
+            dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
+                           src0_row_size_padded, src0_row_size, 2);
+        }
+
+        // Process src0 rows
+        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+            const uint8_t * ss0 = dma_queue_pop(dma_queue);
+
+            for (uint32_t cid = 0; cid < cne1; ++cid) {
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
+                const int               rm1         = row_mapping.i1;  // expert idx
+                const int               rm2         = row_mapping.i2;  // token idx
+
+                const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1;        // src1 row idx
+                const uint8_t * restrict src1_col =
+                    (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
+                float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
+
+                mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+            }
+
+            // Prefetch next (n + spad_nrows) row
+            const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
+            const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+            if (pr0 < src0_end_row_x2) {
+                dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size,
+                               src0_row_size_padded, src0_row_size, 2);
+            }
+        }
+
+        // Process the last row (if any)
+        if (src0_end_row != src0_end_row_x2) {
+            uint32_t       ir0 = src0_end_row_x2;
+            const uint32_t is0 = (ir0 - src0_start_row);
+            dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
+                           src0_row_size_padded, src0_row_size, 1);
+            const uint8_t * ss0 = dma_queue_pop(dma_queue);
+
+            for (uint32_t cid = 0; cid < cne1; ++cid) {
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
+                const int               rm1         = row_mapping.i1;  // expert idx
+                const int               rm2         = row_mapping.i2;  // token idx
+
+                const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1;        // src1 row idx
+                const uint8_t * restrict src1_col =
+                    (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
+                float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
+
+                mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+            }
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
+         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
+         src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],
+         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// q8x4 src1 tensor is already in VTCM spad
+static void matvec_id(struct htp_matmul_type * mt,
+                      struct htp_tensor * restrict src0,
+                      struct htp_tensor * restrict src1,
+                      struct htp_tensor * restrict src2,
+                      struct htp_tensor * restrict dst,
+                      struct htp_spad * restrict src0_spad,
+                      struct htp_spad * restrict src1_spad,
+                      struct htp_spad * restrict src2_spad,
+                      struct htp_spad * restrict dst_spad,
+                      uint32_t    nth,
+                      uint32_t    ith,
+                      uint32_t    src0_nrows_per_thread,
+                      dma_queue * dma_queue) {
+    htp_matmul_preamble;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t src0_nrows = ne01;  // src0 rows per expert
+
+    const uint32_t src0_start_row  = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row    = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    assert(ne13 % ne03 == 0);
+
+    const size_t dst_row_size  = nb1;
+    const size_t src0_row_size = nb01;
+    const size_t src1_row_size = q8x4x2_row_size(ne10);
+
+    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+
+    const uint32_t n_aids = src2->ne[0];  // num activated experts
+    const uint32_t n_ids  = ne02;         // num experts
+
+    // Per-thread VTCM scratchpads for all tensors
+    // Note that the entire src1 tensor is already in VTCM
+    // For other tensors we allocate N rows per thread, padded to HVX vector size
+    uint8_t * restrict spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;
+    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
+    uint8_t * restrict src1_data = src1_spad->data;
+
+    for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) {  // for each expert
+        const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]);
+        assert(eid < n_ids);
+
+        const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02;
+        const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
+        float * restrict dst_row          = (float *) (dst->data + ie1 * nb1);
+
+        // Prefill spad with src0 rows
+        #pragma unroll(4)
+        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+            const int is0 = (ir0 - src0_start_row);
+            if (is0 >= HTP_SPAD_SRC0_NROWS) {
+                break;
+            }
+            dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
+                           src0_row_size_padded, src0_row_size, 2);
+        }
+
+        // Process src0 rows
+        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+            const uint8_t * ss0 = dma_queue_pop(dma_queue);
+            mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+
+            // Prefetch next (n + spad_nrows) row
+            const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
+            const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+            if (pr0 < src0_end_row_x2) {
+                dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size,
+                               src0_row_size_padded, src0_row_size, 2);
+            }
+        }
+
+        // Process the last row (if any)
+        if (src0_end_row != src0_end_row_x2) {
+            uint32_t       ir0 = src0_end_row_x2;
+            const uint32_t is0 = (ir0 - src0_start_row);
+            dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
+                           src0_row_size_padded, src0_row_size, 1);
+            const uint8_t * ss0 = dma_queue_pop(dma_queue);
+            mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
+         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
+         src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],
+         dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// *** matmul in fp16
+
+static void matmul_f16_f32(struct htp_tensor * restrict src0,
+                           struct htp_tensor * restrict src1,
+                           struct htp_tensor * restrict dst,
+                           struct htp_spad * restrict src0_spad,
+                           struct htp_spad * restrict src1_spad,
+                           struct htp_spad * restrict dst_spad,
+                           uint32_t    nth,
+                           uint32_t    ith,
+                           uint32_t    src0_nrows_per_thread,
+                           dma_queue * dma_queue) {
+    htp_matmul_preamble;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const size_t src0_row_size = sizeof(__fp16) * ne00;
+    const size_t src1_row_size = sizeof(float) * ne10;
+
+    assert(ne12 % ne02 == 0);
+    assert(ne13 % ne03 == 0);
+
+    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
+    const uint32_t nr0 = ne0;
+
+    // This is the size of the rest of the dimensions of the result
+    const uint32_t nr1 = ne1 * ne2 * ne3;
+
+    uint32_t chunk_size = 64;
+
+    // distribute the thread work across the inner or outer loop based on which one is larger
+    uint32_t nchunk0 = nr0 > nr1 ? nth : 1;  // parallelize by src0 rows
+    uint32_t nchunk1 = nr0 > nr1 ? 1 : nth;  // parallelize by src1 rows
+
+    // The number of elements in each chunk
+    const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+    const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
+
+    uint32_t current_chunk = ith;
+
+    const uint32_t ith0 = current_chunk % nchunk0;
+    const uint32_t ith1 = current_chunk / nchunk0;
+
+    const uint32_t ir0_start = dr0 * ith0;
+    const uint32_t ir0_end   = MIN(ir0_start + dr0, nr0);
+
+    const uint32_t ir1_start = dr1 * ith1;
+    const uint32_t ir1_end   = MIN(ir1_start + dr1, nr1);
+
+    // broadcast factors
+    const uint32_t r2 = ne12 / ne02;
+    const uint32_t r3 = ne13 / ne03;
+
+    // no work for this thread
+    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
+        return;
+    }
+
+    // block-tiling attempt
+    const uint32_t blck_0 = 64;
+    const uint32_t blck_1 = 64;
+
+    float tmp[32];
+
+    for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (uint32_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1++) {
+                const uint32_t i13 = (ir1 / (ne12 * ne1));
+                const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
+                const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
+
+                // broadcast src0 into src1
+                const uint32_t i03 = i13 / r3;
+                const uint32_t i02 = i12 / r2;
+
+                const uint32_t i1 = i11;
+                const uint32_t i2 = i12;
+                const uint32_t i3 = i13;
+
+                const uint8_t * restrict src0_row = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
+                const uint8_t * restrict src1_col =
+                    (const uint8_t *) src1->data + (i11 + i12 * ne11 + i13 * ne12 * ne11) * src1_row_size;
+                float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+                for (uint32_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0++) {
+                    vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row + ir0 * src0_row_size, src1_col);
+                }
+
+                hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0);
+            }
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "matmul-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
+         src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// *** dynamic quant
+
+static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
+    assert((unsigned long) x % 128 == 0);
+    assert((unsigned long) y_q % 128 == 0);
+
+    HVX_Vector * vx = (HVX_Vector *) x;
+
+    // Load and convert into QF32
+    HVX_Vector zero   = Q6_V_vsplat_R(0);
+    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);  // 32 elements
+    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);  // 32 elements
+    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);  // 32 elements
+    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);  // 32 elements
+
+    // Convert into fp16
+    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
+    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
+
+    // Compute max and scale
+    HVX_Vector vmax_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
+    vmax_hf            = hvx_vec_reduce_max2_fp16(hvx_vec_abs_fp16(vx23_hf), vmax_hf);
+
+    // Replicate first fp16 scale across all lanes
+    HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
+    vmax_hf         = Q6_V_vdelta_VV(vmax_hf, ctrl);
+
+    HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
+    HVX_Vector vd_hf   = Q6_Vhf_equals_Vqf16(vd_qf16);
+
+    *(HVX_UVector *) y_d = vd_hf;
+
+    // Divide input by the scale
+    HVX_Vector vd_inv_hf = hvx_vec_inverse_fp16(vd_hf);
+    vx01_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf));
+    vx23_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf));
+
+    // Convert to int8
+    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
+    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
+    HVX_Vector vx_i8    = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
+
+    *(HVX_Vector *) y_q = vx_i8;
+}
+
+// Overrides input x
+static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
+    assert(k % 32 == 0);
+    const uint32_t qk = QK_Q8_0x4x2;
+    const uint32_t nb = (k + qk - 1) / qk;
+
+    const uint32_t qrow_size = k;              // int8
+
+    const uint32_t dblk_size = 8 * 2;          // 8x __fp16
+    const uint32_t qblk_size = QK_Q8_0x4x2;    // int8
+
+    uint8_t * restrict y_q = (y + 0);          // quants first
+    uint8_t * restrict y_d = (y + qrow_size);  // then scales
+
+    // Temp scales override input since we're working off of the aligned temp buffer in VTCM
+    uint8_t * restrict t_d = (uint8_t *) x;
+
+    for (uint32_t i = 0; i < nb; i++) {
+        quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
+                                 t_d + (i * 2 + 0) * dblk_size / 2);
+        quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
+                                 t_d + (i * 2 + 1) * dblk_size / 2);
+    }
+
+    // now copy the scales into final location
+    hvx_copy_fp16_ua(y_d, t_d, nb * 8);
+}
+
+static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
+                                 uint8_t * restrict dst,
+                                 struct htp_spad * spad,
+                                 uint32_t          nth,
+                                 uint32_t          ith,
+                                 uint32_t          nrows_per_thread) {
+    uint64_t t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t ne0 = src->ne[0];
+    const uint32_t ne1 = src->ne[1];
+    const uint32_t ne2 = src->ne[2];
+    const uint32_t ne3 = src->ne[3];
+
+    const uint32_t nrows = ne1 * ne2 * ne3;                             // total n_rows
+
+    const uint32_t ir_first = nrows_per_thread * ith;                   // first row
+    const uint32_t ir_last  = MIN(ir_first + nrows_per_thread, nrows);  // last row
+
+    const size_t src_row_size = src->nb[1];
+    const size_t dst_row_size = q8x4x2_row_size(ne0);
+
+    uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first);
+    uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
+    uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
+
+    const size_t src_row_size_padded = htp_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
+    memset(tmp_data, 0, src_row_size_padded);  // zero-out temp row data for padding
+
+    for (uint32_t i = ir_first; i < ir_last; ++i) {
+        htp_l2fetch(src_data, 2, src_row_size, src_row_size);
+        hvx_copy_fp32_aa(tmp_data, src_data, ne0);
+
+        // FARF(HIGH, "quantize-q8x4-row: %u\n", i);
+        quantize_row_fp32_q8x4x2((float *) tmp_data, dst_data, ne0);
+        dst_data += dst_row_size;
+        src_data += src_row_size;
+    }
+
+    uint64_t t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "quantize-fp32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
+         ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+    quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
+}
+
+// ** matmul callbacks for worker_pool
+
+static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "q4x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
+
+    matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
+           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "q4x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
+
+    matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
+           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "q8x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
+
+    matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
+           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "q8x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
+
+    matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
+           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "mxfp4x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
+
+    matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
+           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "mxfp4x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
+
+    matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
+           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void htp_matmul_f16_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+    matmul_f16_f32(&octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
+                   octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+// ** matmul-id callbacks for worker_pool
+
+static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "q4x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
+
+    matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
+              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "q4x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
+
+    matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
+              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "q8x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
+
+    matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
+              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "q8x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
+
+    matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
+              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "mxfp4x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
+
+    matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
+              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "mxfp4x4x2-q8x4x2";
+    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
+
+    matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
+              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+// ** main matmul entry point
+
+int op_matmul(struct htp_ops_context * octx) {
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    htp_matmul_preamble;
+
+    const char * op_type;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;
+    const uint32_t src1_nrows = ne11 * ne12 * ne13;
+
+    const size_t src0_row_size = nb01;
+    const size_t dst_row_size  = nb1;
+    size_t       src1_row_size = nb11;
+
+    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+    size_t       src1_row_size_padded;
+
+    worker_callback_t quant_job_func;
+    worker_callback_t matmul_job_func;
+
+    bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
+
+    switch (src0->type) {
+        case HTP_TYPE_Q4_0:
+            op_type        = "q4x4x2-fp32";
+            quant_job_func = htp_quantize_fp32_q8x4x2;
+            if (src1_nrows > 1) {
+                matmul_job_func = htp_matmul_q4x4x2_q8x4x2;
+            } else {
+                matmul_job_func = htp_matvec_q4x4x2_q8x4x2;
+            }
+
+            src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
+
+            // Entire src1 tensor is placed into the VTCM
+            // For other tensors we allocate N rows per thread, padded to HVX vector size
+
+            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
+
+            // src0 spad is also used in dynamic quantizer to store padded src1 rows
+            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
+            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
+                octx->src0_spad.size_per_thread = src1_row_size_padded;
+            }
+
+            octx->src1_spad.size = octx->src1_spad.size_per_thread;
+            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+            break;
+
+        case HTP_TYPE_Q8_0:
+            op_type        = "q8x4x2-fp32";
+            quant_job_func = htp_quantize_fp32_q8x4x2;
+            if (src1_nrows > 1) {
+                matmul_job_func = htp_matmul_q8x4x2_q8x4x2;
+            } else {
+                matmul_job_func = htp_matvec_q8x4x2_q8x4x2;
+            }
+
+            src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
+
+            // Entire src1 tensor is placed into the VTCM
+            // For other tensors we allocate N rows per thread, padded to HVX vector size
+
+            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
+
+            // src0 spad is also used in dynamic quantizer to store padded src1 rows
+            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
+            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
+                octx->src0_spad.size_per_thread = src1_row_size_padded;
+            }
+
+            octx->src1_spad.size = octx->src1_spad.size_per_thread;
+            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+            break;
+
+        case HTP_TYPE_MXFP4:
+            op_type        = "mxfp4x4x2-f32";
+            quant_job_func = htp_quantize_fp32_q8x4x2;
+            if (src1_nrows > 1) {
+                matmul_job_func = htp_matmul_mxfp4x4x2_q8x4x2;
+            } else {
+                matmul_job_func = htp_matvec_mxfp4x4x2_q8x4x2;
+            }
+
+            src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
+
+            // Entire src1 tensor is placed into the VTCM
+            // For other tensors we allocate N rows per thread, padded to HVX vector size
+
+            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
+
+            // src0 spad is also used in dynamic quantizer to store padded src1 rows
+            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
+            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
+                octx->src0_spad.size_per_thread = src1_row_size_padded;
+            }
+
+            octx->src1_spad.size = octx->src1_spad.size_per_thread;
+            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+            break;
+
+        case HTP_TYPE_F16:
+            op_type         = "f16-f32";
+            quant_job_func  = NULL;  // htp_quantize_f32_f16;
+            matmul_job_func = htp_matmul_f16_f32;
+
+            // For all tensors we allocate N rows per thread, padded to HVX vector size
+            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size, 256);
+            octx->src1_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC1_NROWS * src1_row_size, 256);
+
+            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+            octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
+            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+
+            need_quant = false;
+            break;
+
+        default:
+            return HTP_STATUS_NO_SUPPORT;
+    }
+
+    // VTCM scratchpads for all tensors
+    size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
+
+    FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type,
+         octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);
+
+    FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0],
+         src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
+         dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);
+
+    // Make sure the reserved vtcm size is sufficient
+    if (octx->ctx->vtcm_size < spad_size) {
+        FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
+             octx->ctx->vtcm_size, spad_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
+
+    octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
+    octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1);  // round up to even
+
+    if (need_quant) {
+        // Run quant jobs
+        const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
+        octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
+    }
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        // Run matmul jobs
+        const uint32_t n_matmul_jobs = octx->n_threads;
+        worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs);
+    }
+
+    return HTP_STATUS_OK;
+}
+
+// ** main matmul-id entry point
+
+int op_matmul_id(struct htp_ops_context * octx) {
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    const struct htp_tensor * ids  = &octx->src2;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    htp_matmul_preamble;
+
+    const char * op_type;
+
+    worker_callback_t quant_job_func;
+    worker_callback_t matmul_id_job_func;
+
+    const size_t src0_row_size = nb01;
+    const size_t dst_row_size  = nb1;
+
+    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+
+    const uint32_t src0_nrows = ne01;  // per expert
+    const uint32_t src1_nrows = ne11 * ne12 * ne13;
+
+    size_t src1_row_size;
+    size_t src1_row_size_padded;
+
+    // row groups
+    const int n_ids = ids->ne[0];  // n_expert_used
+    const int n_as  = ne02;        // n_expert
+
+    size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
+    size_t matrix_row_map_size    = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
+
+    switch (src0->type) {
+        case HTP_TYPE_Q4_0:
+            op_type        = "q4x2x2-f32";
+            quant_job_func = htp_quantize_fp32_q8x4x2;
+            src1_row_size  = q8x4x2_row_size(ne10);  // row size post quantization
+            if (src1_nrows > 1) {
+                matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2;
+            } else {
+                matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2;
+            }
+
+            // Entire src1 tensor is placed into the VTCM
+            // For other tensors we allocate N rows per thread, padded to HVX vector size
+            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
+            octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
+
+            // src0 spad is also used in dynamic quantizer to store padded src1 rows
+            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
+            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
+                octx->src0_spad.size_per_thread = src1_row_size_padded;
+            }
+
+            octx->src2_spad.size = octx->src2_spad.size_per_thread;
+            octx->src1_spad.size = octx->src1_spad.size_per_thread;
+            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+            break;
+
+        case HTP_TYPE_Q8_0:
+            op_type        = "q8x2x2-f32";
+            quant_job_func = htp_quantize_fp32_q8x4x2;
+            src1_row_size  = q8x4x2_row_size(ne10);  // row size post quantization
+            if (src1_nrows > 1) {
+                matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2;
+            } else {
+                matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2;
+            }
+
+            // Entire src1 tensor is placed into the VTCM
+            // For other tensors we allocate N rows per thread, padded to HVX vector size
+            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
+            octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
+
+            // src0 spad is also used in dynamic quantizer to store padded src1 rows
+            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
+            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
+                octx->src0_spad.size_per_thread = src1_row_size_padded;
+            }
+
+            octx->src2_spad.size = octx->src2_spad.size_per_thread;
+            octx->src1_spad.size = octx->src1_spad.size_per_thread;
+            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+            break;
+
+        case HTP_TYPE_MXFP4:
+            op_type        = "mxfp4x2x2-f32";
+            quant_job_func = htp_quantize_fp32_q8x4x2;
+            src1_row_size  = q8x4x2_row_size(ne10);  // row size post quantization
+            if (src1_nrows > 1) {
+                matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2;
+            } else {
+                matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2;
+            }
+
+            // Entire src1 tensor is placed into the VTCM
+            // For other tensors we allocate N rows per thread, padded to HVX vector size
+            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
+            octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
+
+            // src0 spad is also used in dynamic quantizer to store padded src1 rows
+            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
+            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
+                octx->src0_spad.size_per_thread = src1_row_size_padded;
+            }
+
+            octx->src2_spad.size = octx->src2_spad.size_per_thread;
+            octx->src1_spad.size = octx->src1_spad.size_per_thread;
+            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+            break;
+
+        default:
+            return HTP_STATUS_NO_SUPPORT;
+    }
+
+    size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
+
+    FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type,
+         octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);
+
+    FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type,
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+         ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,
+         src1->data, dst->data);
+
+    // Make sure the reserved vtcm size is sufficient
+    if (octx->ctx->vtcm_size < spad_size) {
+        FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
+             octx->ctx->vtcm_size, spad_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+    octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+    octx->dst_spad.data  = octx->src2_spad.data + octx->src2_spad.size;
+
+    octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
+    octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1);  // round up to even
+
+    if (src1_nrows > 1) {
+        // initialize matrix_row_counts and map
+        uint32_t *                matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0;
+        struct mmid_row_mapping * matrix_rows       = (void *) octx->src2_spad.data + matrix_row_counts_size;
+
+        memset(matrix_row_counts, 0, n_as * sizeof(uint32_t));
+
+        // group rows by src0 matrix
+        for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {  // token idx
+            for (uint32_t id = 0; id < n_ids; ++id) {         // expert idx
+                const uint32_t i02 =
+                    *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+
+                assert(i02 >= 0 && i02 < n_as);
+
+                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
+                matrix_row_counts[i02] += 1;
+            }
+        }
+    }
+
+    // Setup worker pool callbacks
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
+        // Run quant jobs
+        const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
+        octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
+    }
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        // Run matmul-id jobs
+        const uint32_t n_matmul_jobs = octx->n_threads;
+        worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs);
+    }
+
+    return HTP_STATUS_OK;
+}
diff --git a/src/ggml-hexagon/htp/ops-utils.h b/src/ggml-hexagon/htp/ops-utils.h
new file mode 100644 (file)
index 0000000..f03ff34
--- /dev/null
@@ -0,0 +1,116 @@
+#ifndef OPS_UTILS_H
+#define OPS_UTILS_H
+
+#include "htp-msg.h"
+
+#ifndef MAX
+#    define MAX(a, b) ((a) > (b) ? (a) : (b))
+#endif
+
+#ifndef MIN
+#    define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+static inline uint64_t htp_get_cycles() {
+    uint64_t cycles = 0;
+    asm volatile(" %0 = c15:14\n" : "=r"(cycles));
+    return cycles;
+}
+
+static inline uint64_t htp_get_pktcnt() {
+    uint64_t pktcnt;
+    asm volatile(" %0 = c19:18\n" : "=r"(pktcnt));
+    return pktcnt;
+}
+
+static inline int32_t htp_is_aligned(void * addr, uint32_t align) {
+    return ((size_t) addr & (align - 1)) == 0;
+}
+
+static inline uint32_t htp_round_up(uint32_t n, uint32_t m) {
+    return m * ((n + m - 1) / m);
+}
+
+static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
+    const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
+    asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));
+}
+
+static inline int32_t htp_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
+    uint32_t left_off  = (size_t) addr & (chunk_size - 1);
+    uint32_t right_off = left_off + n;
+    return right_off <= chunk_size;
+}
+
+static inline void htp_dump_int8_line(char * pref, const int8_t * x, int n) {
+    char str[1024], *p = str;
+    p += sprintf(p, "%s: ", pref);
+    for (int i = 0; i < 16; i++) {
+        p += sprintf(p, "%d, ", x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void htp_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) {
+    char str[1024], *p = str;
+    p += sprintf(p, "%s: ", pref);
+    for (int i = 0; i < n; i++) {
+        p += sprintf(p, "%d, ", x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void htp_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
+    char str[1024], *p = str;
+    p += sprintf(p, "%s: ", pref);
+    for (int i = 0; i < n; i++) {
+        p += sprintf(p, "%d, ", (int) x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void htp_dump_fp16_line(char * pref, const __fp16 * x, uint32_t n) {
+    char str[1024], *p = str;
+    p += sprintf(p, "%s: ", pref);
+    for (int i = 0; i < n; i++) {
+        p += sprintf(p, "%.6f, ", (float) x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void htp_dump_fp32_line(char * pref, const float * x, uint32_t n) {
+    char str[1024], *p = str;
+    p += sprintf(p, "%s: ", pref);
+    for (int i = 0; i < n; i++) {
+        p += sprintf(p, "%.6f, ", x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void htp_dump_f32(char * pref, const float * x, uint32_t n) {
+    uint32_t n0 = n / 16;
+    uint32_t n1 = n % 16;
+
+    uint32_t i = 0;
+    for (; i < n0; i++) {
+        htp_dump_fp32_line(pref, x + (16 * i), 16);
+    }
+    if (n1) {
+        htp_dump_fp32_line(pref, x + (16 * i), n1);
+    }
+}
+
+static inline void htp_dump_f16(char * pref, const __fp16 * x, uint32_t n) {
+    uint32_t n0 = n / 16;
+    uint32_t n1 = n % 16;
+
+    uint32_t i = 0;
+    for (; i < n0; i++) {
+        htp_dump_fp16_line(pref, x + (16 * i), 16);
+    }
+    if (n1) {
+        htp_dump_fp16_line(pref, x + (16 * i), n1);
+    }
+}
+
+#endif /* OPS_UTILS_H */
diff --git a/src/ggml-hexagon/htp/rope-ops.c b/src/ggml-hexagon/htp/rope-ops.c
new file mode 100644 (file)
index 0000000..16afa50
--- /dev/null
@@ -0,0 +1,418 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+#    define FARF_HIGH 1
+#endif
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <HAP_ps.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <qurt_thread.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+#define htp_rope_preamble              \
+    const uint32_t ne00 = src0->ne[0]; \
+    const uint32_t ne01 = src0->ne[1]; \
+    const uint32_t ne02 = src0->ne[2]; \
+    const uint32_t ne03 = src0->ne[3]; \
+                                       \
+    const uint32_t ne0 = dst->ne[0];   \
+    const uint32_t ne1 = dst->ne[1];   \
+    const uint32_t ne2 = dst->ne[2];   \
+    const uint32_t ne3 = dst->ne[3];   \
+                                       \
+    const uint32_t nb00 = src0->nb[0]; \
+    const uint32_t nb01 = src0->nb[1]; \
+    const uint32_t nb02 = src0->nb[2]; \
+    const uint32_t nb03 = src0->nb[3]; \
+                                       \
+    const uint32_t nb0 = dst->nb[0];   \
+    const uint32_t nb1 = dst->nb[1];   \
+    const uint32_t nb2 = dst->nb[2];   \
+    const uint32_t nb3 = dst->nb[3];
+
+struct rope_th_ctx {
+    int32_t n_dims;
+    int32_t mode;
+    int32_t n_ctx_orig;
+    int32_t sections[4];
+
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+    float theta_scale;
+    float corr_dims[2];
+
+    struct htp_ops_context * octx;
+};
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+
+    return (1 - MIN(1, MAX(0, y)));
+}
+
+static void rope_cache_init(const float   theta_base,
+                            float         freq_scale,
+                            const float * freq_factors,
+                            float *       corr_dims,
+                            uint32_t      ne0,
+                            float         ext_factor,
+                            float         mscale,
+                            float *       cache,
+                            float         theta_scale) {
+    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+    float theta = theta_base;
+
+    for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
+        const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
+
+        float theta_extrap = theta / ff;
+
+        // Get n-d rotational scaling corrected for extrapolation
+        float theta_interp = freq_scale * theta_extrap;
+        float theta2       = theta_interp;
+
+        if (ext_factor != 0.0f) {
+            float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+            theta2         = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+            // Get n-d magnitude scaling corrected for interpolation
+            mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+        }
+
+        cache[i0 + 0] = cosf(theta2) * mscale;
+        cache[i0 + 1] = sinf(theta2) * mscale;
+
+        theta *= theta_scale;
+    }
+}
+
+#define M_PI 3.1415926535897932384626433
+
+static void rope_corr_dims(int     n_dims,
+                           int     n_ctx_orig,
+                           float   freq_base,
+                           float   beta_fast,
+                           float   beta_slow,
+                           float * dims) {
+    float start = floorf(n_dims * logf(n_ctx_orig / (beta_fast * 2 * (float) M_PI)) / (2 * logf(freq_base)));
+    float end   = ceilf(n_dims * logf(n_ctx_orig / (beta_slow * 2 * (float) M_PI)) / (2 * logf(freq_base)));
+    dims[0]     = MAX(0, start);
+    dims[1]     = MIN(n_dims - 1, end);
+}
+
+static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
+    memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
+
+    const int32_t * op_params = &octx->op_params[0];
+
+    rope_ctx->n_dims     = ((const int32_t *) op_params)[1];
+    rope_ctx->mode       = ((const int32_t *) op_params)[2];
+    rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
+
+    memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
+    memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
+    memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
+    memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
+    memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
+    memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
+    memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
+
+    rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
+
+    rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
+                   rope_ctx->beta_slow, rope_ctx->corr_dims);
+
+    rope_ctx->octx = octx;
+    FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
+         rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
+}
+
+static void hvx_calc_rope_f32(const float * restrict src0,
+                              float * restrict dst,
+                              const int num_elems,
+                              const float * restrict theta_cache) {
+    // for (int i = 0; i < num_elems; i += 2) {
+    //const float cos_theta = theta_cache[i + 0];
+    //const float sin_theta = theta_cache[i + 1];
+
+    //const float x0 = src[0];
+    //const float x1 = src[1];
+
+    //dst[0] = x0*cos_theta - x1*sin_theta;
+    //dst[1] = x0*sin_theta + x1*cos_theta;
+
+    //src += 2;
+    //dst += 2;
+    // }
+
+    const uint8_t * restrict src0_curr  = (const uint8_t *) src0;
+    const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
+    uint8_t * restrict dst_curr         = (uint8_t *) dst;
+
+    int step_of_1 = num_elems >> 6;  // 6 because we process two vectors at once
+
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector v0 = *(HVX_Vector *) src0_curr;
+        HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
+
+        HVX_Vector v2 = *(HVX_Vector *) theta_curr;
+        HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
+
+        HVX_VectorPair vx0_x1   = Q6_W_vdeal_VVR(v1, v0, -4);  // vx0_x1[0] = x0, vx0_x1[1] = x1
+        HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);  // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
+
+        HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin));
+        HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin));
+        HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin));
+        HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin));
+
+        HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
+        HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
+
+        HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
+
+        *(HVX_Vector *) dst_curr          = Q6_V_lo_W(vstore);
+        *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
+
+        src0_curr += 2 * VLEN;
+        theta_curr += 2 * VLEN;
+        dst_curr += 2 * VLEN;
+    }
+}
+
+static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
+                         const uint32_t       ir0,
+                         const uint32_t       ir1,
+                         int                  nth,
+                         int                  ith,
+                         int                  opt_path) {
+    struct htp_ops_context * octx = rope_ctx->octx;
+
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    const struct htp_tensor * src2 = &octx->src2;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    htp_rope_preamble;
+
+    const int32_t * pos = (const int32_t *) src1->data;
+
+    float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
+
+    const float * freq_factors = NULL;
+    if (src2 != NULL) {
+        freq_factors = (const float *) src2->data;
+    }
+
+    int ir = 0;
+
+    for (uint32_t i3 = 0; i3 < ne3; i3++) {      // batch
+        for (uint32_t i2 = 0; i2 < ne2; i2++) {  // seq-len
+            const int32_t p = pos[i2];
+
+            rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
+                            rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
+
+            for (uint32_t i1 = 0; i1 < ne1; i1++) {  // attn-heads
+                if (ir++ < ir0) {
+                    continue;
+                }
+                if (ir > ir1) {
+                    break;
+                }
+
+                const float * src      = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
+                float *       dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
+
+                const float * src_loc      = src;
+                float *       dst_data_loc = dst_data;
+
+                if (1 == opt_path) {
+                    hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
+                } else {
+                    for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
+                        const float cos_theta = wp0[i0 + 0];
+                        const float sin_theta = wp0[i0 + 1];
+
+                        const float x0 = src_loc[0];
+                        const float x1 = src_loc[1];
+
+                        dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
+                        dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
+
+                        src_loc += 2;
+                        dst_data_loc += 2;
+                    }
+                }
+
+                for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
+                    dst_data_loc[0] = src_loc[0];
+                    dst_data_loc[1] = src_loc[1];
+
+                    src_loc += 2;
+                    dst_data_loc += 2;
+                }
+            }
+        }
+    }
+}
+
+static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
+    struct htp_ops_context * octx = rope_ctx->octx;
+
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    htp_rope_preamble;
+
+    const uint32_t src0_nrows            = ne01 * ne02 * ne03;  // src0 rows
+    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    int is_aligned = 1;
+    int opt_path   = 0;
+    if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
+        (0 == htp_is_aligned((void *) dst->data, VLEN))) {
+        FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
+        is_aligned = 0;
+    }
+    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
+        opt_path = 1;
+    }
+
+    rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
+    struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
+
+    rope_job_f32_per_thread(rope_ctx, n, i);
+}
+
+static int execute_op_rope_f32(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    const struct htp_tensor * src2 = &octx->src2;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    worker_callback_t op_func;
+    const char *      op_type = NULL;
+
+    struct rope_th_ctx rope_ctx;
+
+    switch (octx->op) {
+        case HTP_OP_ROPE:
+            op_func = rope_job_dispatcher_f32;
+            op_type = "rope-f32";
+
+            init_rope_ctx(&rope_ctx, octx);
+            break;
+
+        default:
+            FARF(ERROR, "Unsupported Op %u\n", octx->op);
+            return HTP_STATUS_NO_SUPPORT;
+    }
+
+    const uint32_t n_threads = octx->n_threads;
+
+    const size_t src0_row_size = src0->nb[1];
+    const size_t src1_row_size = src0_row_size;
+    const size_t dst_row_size  = dst->nb[1];
+
+    // VTCM scratchpads for all tensors
+    // N rows per thread, padded to HVX vector size
+    octx->dst_spad.size  = htp_round_up(dst_row_size, 128) * n_threads;
+    octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
+    octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
+
+    size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+
+    if (src2->ne[0]) {
+        FARF(HIGH,
+             "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
+             "dst-spad-size %u\n",
+             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
+             src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
+             dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+    } else {
+        FARF(HIGH,
+             "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
+             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
+             src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
+             octx->dst_spad.size);
+    }
+
+    // Make sure the reserved vtcm size is sufficient
+    if (octx->ctx->vtcm_size < spad_size) {
+        FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
+             spad_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
+
+    uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        uint32_t n_jobs             = MIN(n_threads, src0_nrows);
+        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
+    }
+
+    return err;
+}
+
+int op_rope(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    switch (octx->src0.type) {
+        case HTP_TYPE_F32:
+            err = execute_op_rope_f32(octx);
+            break;
+
+        default:
+            err = HTP_STATUS_NO_SUPPORT;
+            break;
+    }
+
+    return err;
+}
diff --git a/src/ggml-hexagon/htp/softmax-ops.c b/src/ggml-hexagon/htp/softmax-ops.c
new file mode 100644 (file)
index 0000000..5bf0cbf
--- /dev/null
@@ -0,0 +1,402 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+#    define FARF_HIGH 1
+#endif
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <HAP_ps.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <qurt_thread.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+#define htp_softmax_preamble3                              \
+    const uint32_t ne00 = src0->ne[0];                     \
+    const uint32_t ne01 = src0->ne[1];                     \
+    const uint32_t ne02 = src0->ne[2];                     \
+    const uint32_t ne03 = src0->ne[3];                     \
+                                                           \
+    const uint32_t nb00 = src0->nb[0];                     \
+    const uint32_t nb01 = src0->nb[1];                     \
+    const uint32_t nb02 = src0->nb[2];                     \
+    const uint32_t nb03 = src0->nb[3];                     \
+                                                           \
+    const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \
+    const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \
+    const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \
+    const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \
+                                                           \
+    const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \
+    const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \
+    const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \
+    const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \
+                                                           \
+    const uint32_t ne0 = dst->ne[0];                       \
+    const uint32_t ne1 = dst->ne[1];                       \
+    const uint32_t ne2 = dst->ne[2];                       \
+    const uint32_t ne3 = dst->ne[3];                       \
+                                                           \
+    const uint32_t nb0 = dst->nb[0];                       \
+    const uint32_t nb1 = dst->nb[1];                       \
+    const uint32_t nb2 = dst->nb[2];                       \
+    const uint32_t nb3 = dst->nb[3];
+
+struct softmax_th_ctx {
+    bool     use_f16;
+    bool     use_src1;
+    uint32_t n_head;
+    uint32_t n_head_log2;
+
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+
+    struct htp_ops_context * octx;
+};
+
+static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+
+    memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
+
+    memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
+    memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
+
+    softmax_ctx->n_head      = src0->ne[2];
+    softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
+
+    softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
+    softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
+
+    softmax_ctx->use_src1 = (src1->ne[0] != 0);
+    softmax_ctx->use_f16  = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
+
+    softmax_ctx->octx = octx;
+}
+
+static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
+                                      uint8_t * restrict dst,
+                                      const int num_elems,
+                                      float     scale,
+                                      const uint8_t * restrict mask,
+                                      float slope) {
+    const uint8_t * restrict src_curr  = src;
+    uint8_t * restrict dst_curr        = dst;
+    const uint8_t * restrict mask_curr = mask;
+
+    HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
+    HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
+
+    int step_of_1 = num_elems >> 5;
+
+    #pragma unroll(4)
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector v1 = *(HVX_Vector *) src_curr;
+
+        HVX_Vector v3 = *(HVX_Vector *) mask_curr;
+
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);
+
+        HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v3, slope_vec);
+
+        HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, v4);
+
+        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v5);
+
+        src_curr += VLEN;
+        dst_curr += VLEN;
+        mask_curr += VLEN;
+    }
+}
+
+static void hvx_fast_softmax_f32(const uint8_t * restrict src,
+                                 uint8_t * restrict dst,
+                                 uint8_t * restrict pad,
+                                 const int num_elems) {
+    const HVX_Vector * restrict v_src = (HVX_Vector *) src;
+    HVX_Vector * restrict v_pad       = (HVX_Vector *) pad;
+    HVX_Vector * restrict v_dst       = (HVX_Vector *) dst;
+
+    HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000);
+    HVX_Vector max_vec = hvx_vec_splat_fp32(((const float *) src)[0]);
+    HVX_Vector zero_v  = Q6_V_vzero();
+    HVX_Vector one_v   = hvx_vec_splat_fp32(1.0);
+
+    int step_of_1 = num_elems >> 5;
+
+    #pragma unroll(4)
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector v1 = v_src[i];
+        max_vec       = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
+    }
+
+    HVX_Vector v = hvx_vec_reduce_max_fp32(max_vec);
+    max_vec      = hvx_vec_repl4(v);
+
+    #pragma unroll(4)
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector v1 = v_src[i];
+        HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec);
+
+        HVX_Vector v3 = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(v2));
+
+        sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3);
+
+        v_pad[i] = v3;
+    }
+
+    v       = hvx_vec_qf32_reduce_sum(sum_vec);
+    sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v));
+
+    HVX_VectorPred pos_sum   = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
+    HVX_Vector     v4        = hvx_vec_inverse_fp32(sum_vec);
+    HVX_Vector     scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v);
+
+    #pragma unroll(4)
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector v1 = v_pad[i];
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);
+        v_dst[i]      = Q6_Vsf_equals_Vqf32(v2);
+    }
+}
+
+static float hvx_softmax_f32(const uint8_t * restrict src,
+                             uint8_t * restrict dst,
+                             uint8_t * restrict spad,
+                             const int   num_elems,
+                             const float max) {
+    hvx_sub_scalar_f32(src, max, spad, num_elems);
+
+    hvx_exp_f32(spad, dst, num_elems, false);
+
+    float sum = hvx_self_sum_f32(dst, num_elems);
+
+    return sum;
+}
+
+static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
+    struct htp_ops_context * octx = softmax_ctx->octx;
+
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    const struct htp_tensor * dst  = &octx->dst;
+
+    htp_softmax_preamble3;
+
+    uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
+    uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
+    uint8_t * dst_spad_data  = octx->dst_spad.data + (ith * nb1);
+
+    float * wp0 = (float *) src0_spad_data;
+    float * wp1 = (float *) src1_spad_data;
+    float * wp2 = (float *) dst_spad_data;
+
+    for (uint32_t i03 = 0; i03 < ne03; i03++) {
+        for (uint32_t i02 = 0; i02 < ne02; i02++) {
+            for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
+                const uint32_t i11 = i01;
+                const uint32_t i12 = i02 % ne12;
+                const uint32_t i13 = i03 % ne13;
+
+                // ALiBi
+                const uint32_t h = i02;  // head
+
+                const float slope = (softmax_ctx->max_bias > 0.0f) ?
+                                        h < softmax_ctx->n_head_log2 ?
+                                        powf(softmax_ctx->m0, h + 1) :
+                                        powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
+                                        1.0f;
+
+                float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+
+                // broadcast the mask across rows
+                __fp16 * mp_f16 = (softmax_ctx->use_src1) ?
+                                      (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
+                                      NULL;
+                float *  mp_f32 = (softmax_ctx->use_src1) ?
+                                      (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
+                                      NULL;
+
+                if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
+                    hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
+                                              (const uint8_t *) mp_f32, slope);
+                } else {
+                    hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale);
+                    if (mp_f32) {
+                        if (softmax_ctx->use_f16) {
+                            for (int i = 0; i < ne00; ++i) {
+                                wp0[i] += slope * (float) mp_f16[i];
+                            }
+                        } else {
+                            for (int i = 0; i < ne00; ++i) {
+                                wp0[i] += slope * mp_f32[i];
+                            }
+                        }
+                    }
+                }
+
+                if (1 == opt_path) {
+                    hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
+                } else {
+                    float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
+                    float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
+                    sum       = sum > 0.0 ? (1.0 / sum) : 1;
+                    hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum);
+                }
+            }
+        }
+    }
+}
+
+static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
+    struct htp_ops_context * octx = softmax_ctx->octx;
+
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    htp_softmax_preamble3;
+
+    const uint32_t src0_nrows            = ne01 * ne02 * ne03;  // src0 rows
+    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    int is_aligned = 1;
+    int opt_path   = 0;
+    if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
+        is_aligned = 0;
+        FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n");
+    }
+    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
+        opt_path = 1;
+    }
+
+    softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+         softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
+         ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
+    struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
+    softmax_job_f32_per_thread(p_softmax_ctx, n, i);
+}
+
+static int execute_op_softmax_f32(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    worker_callback_t op_func;
+    const char *      op_type = NULL;
+
+    struct softmax_th_ctx softmax_ctx;
+
+    switch (octx->op) {
+        case HTP_OP_SOFTMAX:
+            op_func = softmax_job_dispatcher_f32;
+            op_type = "softmax-f32";
+
+            init_softmax_ctx(&softmax_ctx, octx);
+            break;
+
+        default:
+            FARF(ERROR, "Unsupported Op %u\n", octx->op);
+            return HTP_STATUS_NO_SUPPORT;
+    }
+
+    const uint32_t n_threads = octx->n_threads;
+
+    const size_t src0_row_size = src0->nb[1];
+    const size_t src1_row_size = src0_row_size;
+    const size_t dst_row_size  = dst->nb[1];
+
+    // VTCM scratchpads for all tensors
+    // N rows per thread, padded to HVX vector size
+    octx->dst_spad.size  = htp_round_up(dst_row_size, 128) * n_threads;
+    octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
+    octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
+
+    size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+
+    if (src1->ne[0]) {
+        FARF(HIGH,
+             "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
+             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
+             src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
+             octx->dst_spad.size);
+    } else {
+        FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
+             src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+             octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+    }
+
+    // Make sure the reserved vtcm size is sufficient
+    if (octx->ctx->vtcm_size < spad_size) {
+        FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
+             spad_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
+
+    uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        uint32_t n_jobs             = MIN(n_threads, src0_nrows);
+        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
+    }
+
+    return err;
+}
+
+int op_softmax(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    switch (octx->src0.type) {
+        case HTP_TYPE_F32:
+            err = execute_op_softmax_f32(octx);
+            break;
+
+        default:
+            err = HTP_STATUS_NO_SUPPORT;
+            break;
+    }
+
+    return err;
+}
diff --git a/src/ggml-hexagon/htp/unary-ops.c b/src/ggml-hexagon/htp/unary-ops.c
new file mode 100644 (file)
index 0000000..bb7557b
--- /dev/null
@@ -0,0 +1,255 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+#    define FARF_HIGH 1
+#endif
+
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <HAP_ps.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <qurt_thread.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+#define htp_unary_preamble            \
+    const uint32_t ne00 = src->ne[0]; \
+    const uint32_t ne01 = src->ne[1]; \
+    const uint32_t ne02 = src->ne[2]; \
+    const uint32_t ne03 = src->ne[3]; \
+                                      \
+    const uint32_t ne0 = dst->ne[0];  \
+    const uint32_t ne1 = dst->ne[1];  \
+    const uint32_t ne2 = dst->ne[2];  \
+    const uint32_t ne3 = dst->ne[3];  \
+                                      \
+    const uint32_t nb00 = src->nb[0]; \
+    const uint32_t nb01 = src->nb[1]; \
+    const uint32_t nb02 = src->nb[2]; \
+    const uint32_t nb03 = src->nb[3]; \
+                                      \
+    const uint32_t nb0 = dst->nb[0];  \
+    const uint32_t nb1 = dst->nb[1];  \
+    const uint32_t nb2 = dst->nb[2];  \
+    const uint32_t nb3 = dst->nb[3];
+
+static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
+                                  uint8_t * restrict dst,
+                                  uint8_t * restrict pad,
+                                  const int num_elems,
+                                  float     epsilon) {
+    const HVX_Vector * restrict v_src = (HVX_Vector *) src;
+    HVX_Vector * restrict v_dst       = (HVX_Vector *) dst;
+
+    HVX_Vector sum_v     = Q6_V_vsplat_R(0x00000000);
+    HVX_Vector epsilon_v = hvx_vec_splat_fp32(epsilon);
+
+    int step_of_1 = num_elems >> 5;
+    #pragma unroll(4)
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector v1 = v_src[i];
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
+        sum_v         = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
+    }
+
+    HVX_Vector reduced_sum = hvx_vec_qf32_reduce_sum(sum_v);
+    sum_v                  = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum));
+
+    HVX_Vector t_v            = hvx_vec_splat_fp32((float) num_elems);
+    HVX_Vector denom_v        = hvx_vec_inverse_fp32(t_v);
+    HVX_Vector mean_v         = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
+    HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
+
+    HVX_Vector scale_v = hvx_vec_rsqrt_fp32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
+
+    #pragma unroll(4)
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector v1 = v_src[i];
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
+        v_dst[i]      = Q6_Vsf_equals_Vqf32(v2);
+    }
+}
+
+static void rms_norm_htp_f32(const float * restrict src,
+                             float * restrict dst,
+                             uint8_t * restrict spad,
+                             const uint32_t num_rows,
+                             const uint32_t row_elems,
+                             const size_t   row_size,
+                             int32_t *      op_params,
+                             int            opt_path) {
+    float epsilon = 0.f;
+    memcpy(&epsilon, op_params, sizeof(float));
+
+    for (uint32_t ir = 0; ir < num_rows; ir++) {
+        const float * restrict src_local = src + (ir * row_elems);
+        float * restrict dst_local       = dst + (ir * row_elems);
+
+        if (ir + 1 < num_rows) {
+            htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
+        }
+
+        if (1 == opt_path) {
+            hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
+        } else {
+            float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
+
+            const float mean  = sum / row_elems;
+            const float scale = 1.0f / sqrtf(mean + epsilon);
+
+            hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale);
+        }
+    }
+}
+
+static void unary_job_f32_per_thread(const struct htp_tensor * src,
+                                     struct htp_tensor *       dst,
+                                     uint8_t *                 spad,
+                                     int                       htp_op,
+                                     int32_t *                 op_params,
+                                     uint32_t                  nth,
+                                     uint32_t                  ith,
+                                     uint32_t                  src0_nrows_per_thread) {
+    htp_unary_preamble;
+
+    const size_t src0_row_size = nb01;
+    const size_t dst_row_size  = nb1;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    int is_aligned = 1;
+    int opt_path   = 0;
+    if ((0 == htp_is_aligned((void *) src->data, VLEN)) || (0 == htp_is_aligned((void *) dst->data, VLEN))) {
+        is_aligned = 0;
+        FARF(HIGH, "unary-f32: unaligned addresses in unary op, possibly slower execution\n");
+    }
+    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
+        opt_path = 1;
+    }
+
+    const uint8_t * restrict data_src = (const uint8_t *) src->data;
+    uint8_t * restrict data_dst       = (uint8_t *) dst->data;
+
+    const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
+    float * restrict dst_th       = (float *) (data_dst + (src0_start_row * dst_row_size));
+    uint8_t * restrict spad_th    = (uint8_t *) spad + (ith * nb01);
+
+    switch (htp_op) {
+        case HTP_OP_RMS_NORM:
+            rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
+            break;
+
+        default:
+            break;
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
+         src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
+         dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+
+    unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
+                             octx->src0_nrows_per_thread);
+}
+
+static int execute_op_unary_f32(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    const struct htp_tensor * src0 = &octx->src0;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    worker_callback_t unary_op_func;
+    const char *      op_type = NULL;
+
+    switch (octx->op) {
+        case HTP_OP_RMS_NORM:
+            unary_op_func = unary_job_dispatcher_f32;
+            op_type       = "rmsnorm-f32";
+            break;
+
+        default:
+            FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
+            return HTP_STATUS_NO_SUPPORT;
+    }
+
+    const int      n_threads  = octx->n_threads;
+    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+
+    const size_t src0_row_size = src0->nb[1];
+    const size_t dst_row_size  = dst->nb[1];
+
+    // VTCM scratchpads for all tensors
+    octx->dst_spad.size  = htp_round_up(dst_row_size, 128) * n_threads;
+    octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
+
+    size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
+
+    FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+
+    // Make sure the reserved vtcm size is sufficient
+    if (octx->ctx->vtcm_size < spad_size) {
+        FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
+             spad_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->dst_spad.data  = octx->src0_spad.data + octx->src0_spad.size;
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        uint32_t n_jobs = MIN(n_threads, src0_nrows);
+
+        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+
+        worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
+    }
+
+    return err;
+}
+
+int op_unary(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    switch (octx->src0.type) {
+        case HTP_TYPE_F32:
+            err = execute_op_unary_f32(octx);
+            break;
+
+        default:
+            err = HTP_STATUS_NO_SUPPORT;
+            break;
+    }
+
+    return err;
+}
diff --git a/src/ggml-hexagon/htp/worker-pool.c b/src/ggml-hexagon/htp/worker-pool.c
new file mode 100644 (file)
index 0000000..cd38c21
--- /dev/null
@@ -0,0 +1,297 @@
+#include "worker-pool.h"
+
+#include <qurt.h>
+#include <stdatomic.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#ifdef HTP_DEBUG
+#    define FARF_HIGH 1
+#endif
+
+#include "HAP_farf.h"
+
+#define WORKER_THREAD_STACK_SZ  (2 * 16384)
+#define LOWEST_USABLE_QURT_PRIO (254)
+
+struct worker_pool_s;
+
+// internal structure kept in thread-local storage per instance of worker pool
+typedef struct {
+    struct worker_pool_s * pool;
+    unsigned int           id;
+} worker_context_t;
+
+// internal structure kept in thread-local storage per instance of worker pool
+typedef struct worker_pool_s {
+    worker_pool_job_t job[MAX_NUM_WORKERS];      // list of job descriptors
+    qurt_thread_t     thread[MAX_NUM_WORKERS];   // thread ID's of the workers
+    worker_context_t  context[MAX_NUM_WORKERS];  // worker contexts
+    void *            stack[MAX_NUM_WORKERS];    // thread stack pointers
+    unsigned int      n_threads;                 // number of workers in this pool
+
+    atomic_uint seqn;                            // seqno used to detect new jobs
+    atomic_uint next_job;                        // next job index
+    atomic_uint n_pending;                       // number of pending jobs
+    atomic_uint n_jobs;                          // number of current jobs
+    atomic_bool killed;                          // threads need to exit
+} worker_pool_t;
+
+static void worker_pool_main(void * context) {
+    worker_context_t * me   = (worker_context_t *) context;
+    worker_pool_t *    pool = me->pool;
+
+    FARF(HIGH, "worker-pool: thread %u started", me->id);
+
+    unsigned int prev_seqn = 0;
+    while (!atomic_load(&pool->killed)) {
+        unsigned int seqn = atomic_load(&pool->seqn);
+        if (seqn == prev_seqn) {
+            // Nothing to do
+            qurt_futex_wait(&pool->seqn, prev_seqn);
+            continue;
+        }
+
+        // New job
+        prev_seqn = seqn;
+
+        unsigned int n = atomic_load(&pool->n_jobs);
+        unsigned int i = atomic_fetch_add(&pool->next_job, 1);
+        if (i >= n) {
+            // Spurios wakeup
+            continue;
+        }
+
+        pool->job[i].func(n, i, pool->job[i].data);
+
+        atomic_fetch_sub(&pool->n_pending, 1);
+    }
+
+    FARF(HIGH, "worker-pool: thread %u stopped", me->id);
+}
+
+AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context, uint32_t n_threads, uint32_t stack_size) {
+    int err = 0;
+
+    if (NULL == context) {
+        FARF(ERROR, "NULL context passed to worker_pool_init().");
+        return AEE_EBADPARM;
+    }
+
+    // Allocations
+    int size = (stack_size * n_threads) + (sizeof(worker_pool_t));
+
+    unsigned char * mem_blob = (unsigned char *) malloc(size);
+    if (!mem_blob) {
+        FARF(ERROR, "Could not allocate memory for worker pool!!");
+        return AEE_ENOMEMORY;
+    }
+
+    worker_pool_t * me = (worker_pool_t *) (mem_blob + stack_size * n_threads);
+
+    // name for the first worker, useful in debugging threads
+    char name[19];
+    snprintf(name, 12, "0x%8x:", (int) me);
+    strcat(name, "worker0");
+    me->n_threads = n_threads;
+
+    // initializations
+    for (unsigned int i = 0; i < me->n_threads; i++) {
+        me->stack[i]  = NULL;
+        me->thread[i] = 0;
+
+        me->context[i].id   = i;
+        me->context[i].pool = me;
+    }
+
+    // initialize job queue
+    me->n_pending = 0;
+    me->n_jobs    = 0;
+    me->next_job  = 0;
+    me->seqn      = 0;
+    me->killed    = 0;
+
+    // launch the workers
+    qurt_thread_attr_t attr;
+    qurt_thread_attr_init(&attr);
+
+    for (unsigned int i = 0; i < me->n_threads; i++) {
+        // set up stack
+        me->stack[i] = mem_blob;
+        mem_blob += stack_size;
+        qurt_thread_attr_set_stack_addr(&attr, me->stack[i]);
+        qurt_thread_attr_set_stack_size(&attr, stack_size);
+
+        // set up name
+        qurt_thread_attr_set_name(&attr, name);
+        name[17] = (name[17] + 1);
+        // name threads context:worker0, context:worker1, .. (recycle at 9, but num threads should be less than that anyway)
+        if (name[17] > '9') {
+            name[17] = '0';
+        }
+
+        // set up priority - by default, match the creating thread's prio
+        int prio = qurt_thread_get_priority(qurt_thread_get_id());
+
+        if (prio < 1) {
+            prio = 1;
+        }
+        if (prio > LOWEST_USABLE_QURT_PRIO) {
+            prio = LOWEST_USABLE_QURT_PRIO;
+        }
+
+        qurt_thread_attr_set_priority(&attr, prio);
+
+        // launch
+        err = qurt_thread_create(&me->thread[i], &attr, worker_pool_main, (void *) &me->context[i]);
+        if (err) {
+            FARF(ERROR, "Could not launch worker threads!");
+            worker_pool_release((worker_pool_context_t *) &me);
+            return AEE_EQURTTHREADCREATE;
+        }
+    }
+    *context = (worker_pool_context_t *) me;
+    return AEE_SUCCESS;
+}
+
+AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads) {
+    return worker_pool_init_with_stack_size(context, n_threads, WORKER_THREAD_STACK_SZ);
+}
+
+// clean up worker pool
+void worker_pool_release(worker_pool_context_t * context) {
+    worker_pool_t * me = (worker_pool_t *) *context;
+
+    // if no worker pool exists, return error.
+    if (NULL == me) {
+        return;
+    }
+
+    atomic_store(&me->killed, 1);
+    atomic_fetch_add(&me->seqn, 1);
+    qurt_futex_wake(&me->seqn, me->n_threads);
+
+    // de-initializations
+    for (unsigned int i = 0; i < me->n_threads; i++) {
+        if (me->thread[i]) {
+            int status;
+            (void) qurt_thread_join(me->thread[i], &status);
+        }
+    }
+
+    // free allocated memory (were allocated as a single buffer starting at stack[0])
+    if (me->stack[0]) {
+        free(me->stack[0]);
+    }
+
+    *context = NULL;
+}
+
+// run jobs
+AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n) {
+    worker_pool_t * me = (worker_pool_t *) context;
+    if (NULL == me) {
+        FARF(ERROR, "worker-pool: invalid context");
+        return AEE_EBADPARM;
+    }
+
+    if (n > me->n_threads) {
+        FARF(ERROR, "worker-pool: invalid number of jobs %u for n-threads %u", n, me->n_threads);
+        return AEE_EBADPARM;
+    }
+
+    memcpy(me->job, job, sizeof(worker_pool_job_t) * n);
+
+    if (n > 1) {
+        atomic_store(&me->next_job, 1);
+        atomic_store(&me->n_jobs, n);
+        atomic_store(&me->n_pending, n - 1);
+
+        // wake up workers
+        atomic_fetch_add(&me->seqn, 1);
+        qurt_futex_wake(&me->seqn, n - 1);
+    }
+
+    // main thread runs job #0
+    me->job[0].func(n, 0, me->job[0].data);
+
+    if (n > 1) {
+        while (atomic_load(&me->n_pending))
+            ;
+    }
+
+    return 0;
+}
+
+// run func
+AEEResult worker_pool_run_func(worker_pool_context_t context, worker_callback_t func, void * data, unsigned int n) {
+    worker_pool_job_t job[n];
+
+    for (unsigned int i = 0; i < n; i++) {
+        job[i].func = func;
+        job[i].data = data;
+    }
+
+    return worker_pool_run_jobs(context, job, n);
+}
+
+AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio) {
+    worker_pool_t * me = (worker_pool_t *) context;
+
+    // if no worker pool exists, return error.
+    if (!me) {
+        return AEE_ENOMORE;
+    }
+
+    int result = AEE_SUCCESS;
+    if (prio < 1) {
+        prio = 1;
+    }
+    if (prio > LOWEST_USABLE_QURT_PRIO) {
+        prio = LOWEST_USABLE_QURT_PRIO;
+    }
+
+    for (unsigned int i = 0; i < me->n_threads; i++) {
+        int res = qurt_thread_set_priority(me->thread[i], (unsigned short) prio);
+        if (0 != res) {
+            result = AEE_EBADPARM;
+            FARF(ERROR, "QURT failed to set priority of thread %d, ERROR = %d", me->thread[i], res);
+        }
+    }
+
+    return result;
+}
+
+AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids) {
+    worker_pool_t * me = (worker_pool_t *) context;
+    if (!me) {
+        FARF(ERROR, "worker-pool: invalid context");
+        return AEE_EBADPARM;
+        ;
+    }
+
+    for (int i = 0; i < me->n_threads; i++) {
+        tids[i] = me->thread[i];
+    }
+
+    return AEE_SUCCESS;
+}
+
+AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio) {
+    worker_pool_t * me = (worker_pool_t *) context;
+    if (!me) {
+        FARF(ERROR, "worker-pool: invalid context");
+        return AEE_EBADPARM;
+    }
+
+    int priority = qurt_thread_get_priority(me->thread[0]);
+    if (priority > 0) {
+        *prio = priority;
+        return 0;
+    } else {
+        *prio = 0;
+        return AEE_EBADSTATE;
+    }
+}
diff --git a/src/ggml-hexagon/htp/worker-pool.h b/src/ggml-hexagon/htp/worker-pool.h
new file mode 100644 (file)
index 0000000..6f8c905
--- /dev/null
@@ -0,0 +1,57 @@
+#ifndef HTP_WORKER_POOL_H
+#define HTP_WORKER_POOL_H
+
+// MACRO enables function to be visible in shared-library case.
+#define WORKERPOOL_API __attribute__((visibility("default")))
+
+#include <AEEStdDef.h>
+#include <AEEStdErr.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// signature of callbacks to be invoked by worker threads
+typedef void (*worker_callback_t)(unsigned int n, unsigned int i, void *);
+
+/// Typedef of worker_pool context
+typedef void * worker_pool_context_t;
+
+/// descriptor for requested callback
+typedef struct {
+    worker_callback_t func;
+    void *            data;
+} worker_pool_job_t;
+
+/// Maximum supported number of worker threads.
+#define MAX_NUM_WORKERS 10
+
+// Initialize worker pool.
+WORKERPOOL_API AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads);
+
+// Initialize worker pool with custom stack size
+WORKERPOOL_API AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context,
+                                                          uint32_t                n_threads,
+                                                          uint32_t                stack_size);
+
+// Kill worker threads and release worker pool resources
+WORKERPOOL_API void worker_pool_release(worker_pool_context_t * context);
+
+// Run jobs with the worker pool.
+WORKERPOOL_API AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n);
+
+WORKERPOOL_API AEEResult worker_pool_run_func(worker_pool_context_t context,
+                                              worker_callback_t     func,
+                                              void *                data,
+                                              unsigned int          n);
+
+WORKERPOOL_API AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio);
+WORKERPOOL_API AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio);
+WORKERPOOL_API AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // #ifndef HTP_WORKER_POOL_H