]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
hexagon refactor all Ops to use local context struct (#19819)
authorMax Krasnyansky <redacted>
Tue, 24 Feb 2026 00:32:14 +0000 (16:32 -0800)
committerGitHub <redacted>
Tue, 24 Feb 2026 00:32:14 +0000 (16:32 -0800)
* hexagon: refactor set/get/sum-rows ops to use local context

* hexagon: refactor ROPE and Softmax Ops to use local context

Improves performance a bit by precomputing things and saving in the context.

* hexagon: refactor activation ops to use local context struct

* hexagon: refactor unary ops to use local context struct and DMA/VTCM

* hexagon: use aligned hvx_scale function

* hexagon: remove unused fields from op_context

* hexagon: rewrite ROPE to use DMA and VTCM scratchpad

* hex-rope: keep N rows in scratchpad (instead of just two)

* hex-rope: introduce rowidx cache

* hex-rope: remove unused fields

* hex-rope: rewrite dma prefetch logic to allow for multi-row fetch/compute

also removes the need for fastdiv.

* hex-rope: minor formatting

* hex-rope: use indices and unroll the loops

* hex-rope: more updates to cleanup rope-block handling

* hexagon: cleanup supported type/dims checks

* hexagon: all reduce funcs replicated across lanes

There is no need to explicitly replicate the first value.

* snapdragon: update adb and windows scripts to use ubatch-size 256

Updated Ops support handles larger ubatches.

15 files changed:
ggml/src/ggml-hexagon/ggml-hexagon.cpp
ggml/src/ggml-hexagon/htp/act-ops.c
ggml/src/ggml-hexagon/htp/get-rows-ops.c
ggml/src/ggml-hexagon/htp/hex-dma.h
ggml/src/ggml-hexagon/htp/htp-ops.h
ggml/src/ggml-hexagon/htp/matmul-ops.c
ggml/src/ggml-hexagon/htp/rope-ops.c
ggml/src/ggml-hexagon/htp/set-rows-ops.c
ggml/src/ggml-hexagon/htp/softmax-ops.c
ggml/src/ggml-hexagon/htp/sum-rows-ops.c
ggml/src/ggml-hexagon/htp/unary-ops.c
scripts/snapdragon/adb/run-cli.sh
scripts/snapdragon/adb/run-completion.sh
scripts/snapdragon/adb/run-mtmd.sh
scripts/snapdragon/windows/run-cli.ps1

index 54f9986498f49d45af1cd00e7b6fb3be8958b8d2..7a44443a8a3e0cae98827f07ca27f295bc45e141 100644 (file)
@@ -1749,23 +1749,6 @@ static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backe
     return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
 }
 
-static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
-    if (x->ne[0] != y->ne[0]) {
-        return false;
-    }
-    if (x->ne[1] != y->ne[1]) {
-        return false;
-    }
-    if (x->ne[2] != y->ne[2]) {
-        return false;
-    }
-    if (x->ne[3] != y->ne[3]) {
-        return false;
-    }
-
-    return true;
-}
-
 static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
     const struct ggml_tensor * src0 = op->src[0];
     const struct ggml_tensor * src1 = op->src[1];
@@ -1797,43 +1780,6 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
     return opt_experimental;
 }
 
-static bool hex_supported_src0_type(ggml_type t) {
-    return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_src1_type(ggml_type t) {
-    return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_src2_type(ggml_type t) {
-    return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_src1_type2(ggml_type t) {
-    return t == GGML_TYPE_F16;
-}
-
-static bool hex_supported_src1_type3(ggml_type t) {
-    return t == GGML_TYPE_I32;
-}
-
-static bool hex_supported_dst_type(ggml_type t) {
-    return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
-    // TODO: support broadcast for ne[2 and 3]
-    if (x->ne[0] != y->ne[0]) {
-        return false;
-    }
-    if (x->ne[2] != y->ne[2]) {
-        return false;
-    }
-    if (x->ne[3] != y->ne[3]) {
-        return false;
-    }
-    return true;
-}
 
 static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
@@ -1919,19 +1865,19 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
     const struct ggml_tensor * src1 = op->src[1];
     const struct ggml_tensor * dst  = op;
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_src1_type(src1->type)) {
+    if (src1->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dims2(src0, dst)) {
+    if (!ggml_are_same_shape(src0, dst)) {
         return false;
     }
-    if (!ggml_can_repeat(src1, src0)) {
+    if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) {
         return false;
     }
 
@@ -1943,16 +1889,16 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se
     const struct ggml_tensor * src1 = op->src[1];
     const struct ggml_tensor * dst  = op;
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_src1_type(src1->type)) {
+    if (src1->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dims2(src0, dst)) {
+    if (!ggml_are_same_shape(src0, dst)) {
         return false;
     }
 
@@ -1968,13 +1914,13 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
     const struct ggml_tensor * src0 = op->src[0];
     const struct ggml_tensor * dst  = op;
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dims2(src0, dst)) {
+    if (!ggml_are_same_shape(src0, dst)) {
         return false;
     }
 
@@ -1990,10 +1936,10 @@ static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session *
     const struct ggml_tensor * src0 = op->src[0];
     const struct ggml_tensor * dst  = op;
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
 
@@ -2011,10 +1957,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
     const struct ggml_tensor * src1 = op->src[1];
     const struct ggml_tensor * dst  = op;
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
 
@@ -2023,10 +1969,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
     }
 
     if (src1) {
-        if (!hex_supported_src1_type(src1->type)) {
+        if (src1->type != GGML_TYPE_F32) {
             return false;
         }
-        if (!hex_supported_dims2(src0, src1)) {
+        if (!ggml_are_same_shape(src0, src1)) {
             return false;
         }
         if (!ggml_is_contiguous(src1)) {
@@ -2047,15 +1993,15 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
         return false;  // FIXME: add support for sinks
     }
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
 
     if (src1) {
-        if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
+        if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
             return false;
         }
         if (src0->ne[0] != src1->ne[0]) {
@@ -2162,17 +2108,17 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
     const struct ggml_tensor * src2 = op->src[2];
     const struct ggml_tensor * dst  = op;
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;  // FIXME: add support for GGML_TYPE_F16 for src0
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_src1_type3(src1->type)) {
+    if (src1->type != GGML_TYPE_I32) {
         return false;
     }
     if (src2) {
-        if (!hex_supported_src2_type(src2->type)) {
+        if (src2->type != GGML_TYPE_F32) {
             return false;
         }
         int n_dims = op_params[1];
index 950d836ad349d52f148a9bea5d9cf2da303f1200..21bd4050a1d2706966085554db42aa5d50411c22 100644 (file)
     const uint32_t nb2 = dst->nb[2];   \
     const uint32_t nb3 = dst->nb[3];
 
-static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
-                                       const struct htp_tensor * src1,
-                                       struct htp_tensor *       dst,
-                                       const int32_t *           op_params,
-                                       struct htp_spad *         src0_spad,
-                                       struct htp_spad *         src1_spad,
-                                       struct htp_spad *         dst_spad,
-                                       uint32_t                  nth,
-                                       uint32_t                  ith,
-                                       uint32_t                  src0_nrows_per_thread,
-                                       dma_queue *               dma_queue) {
+struct htp_act_context {
+    struct htp_ops_context *  octx;
+
+    // Precomputed values
+    const uint8_t *           data_src0;
+    const uint8_t *           data_src1;
+    uint8_t *                 data_dst;
+
+    size_t                    src0_row_size;
+    size_t                    src1_row_size;
+    size_t                    dst_row_size;
+
+    size_t                    src0_row_size_aligned;
+    size_t                    src1_row_size_aligned;
+    size_t                    dst_row_size_aligned;
+
+    size_t                    src0_spad_half_size;
+    size_t                    src1_spad_half_size;
+    size_t                    dst_spad_half_size;
+
+    uint32_t                  block;
+    uint32_t                  src0_nrows;
+    uint32_t                  src0_nrows_per_thread;
+    int                       nc;
+};
+
+static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_act_context * actx = (struct htp_act_context *) data;
+    const struct htp_tensor * src0 = &actx->octx->src0;
+    const struct htp_tensor * src1 = &actx->octx->src1;
+    const struct htp_tensor * dst  = &actx->octx->dst;
     htp_act_preamble3;
 
-    size_t src0_row_size = nb01;
-    size_t src1_row_size = nb11;
-    size_t dst_row_size  = nb1;
-
-
-
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    size_t src0_row_size = actx->src0_row_size;
+    size_t src1_row_size = actx->src1_row_size;
+    size_t dst_row_size  = actx->dst_row_size;
 
+    const uint32_t src0_nrows = actx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
 
@@ -101,43 +119,34 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
-    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
-    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
-
-    const bool src1_valid = src1->ne[0];
-    const int  nc         = (src1_valid) ? ne00 : ne00 / 2;
-    if (!src1_valid) {
-        const int32_t swapped = op_params[1];
-        data_src1             = data_src0;
-        src1_row_size         = src0_row_size;
+    const uint8_t * restrict data_src0 = actx->data_src0;
+    const uint8_t * restrict data_src1 = actx->data_src1;
+    uint8_t * restrict data_dst        = actx->data_dst;
 
-        const size_t nc_in_bytes = nc * SIZEOF_FP32;
-        data_src0 += swapped ? nc_in_bytes : 0;
-        data_src1 += swapped ? 0 : nc_in_bytes;
-    }
+    const int  nc = actx->nc;
 
-    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
-    const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
-    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
+    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+    const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
+    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;
 
-    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
-    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
-    uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_spad->size_per_thread);
+    uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+    uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
+    uint8_t * restrict dst_spad_data  = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
 
-    // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
-    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
-    size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
-    size_t dst_spad_half_size  = dst_spad->size_per_thread / 2;
+    size_t src0_spad_half_size = actx->src0_spad_half_size;
+    size_t src1_spad_half_size = actx->src1_spad_half_size;
+    size_t dst_spad_half_size  = actx->dst_spad_half_size;
 
-    const int BLOCK = src0_spad_half_size / src0_row_size_aligned;  // How many rows can we process in one block
+    const int BLOCK = actx->block;
     if (BLOCK == 0) {
         FARF(ERROR,
              "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
-             src0_spad->size_per_thread, src0_row_size_aligned);
+             actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
         return;
     }
 
+    dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
     // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
     for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
         const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@@ -196,27 +205,22 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
          (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
-                                           const struct htp_tensor * src1,
-                                           struct htp_tensor *       dst,
-                                           const int32_t *           op_params,
-                                           struct htp_spad *         src0_spad,
-                                           struct htp_spad *         src1_spad,
-                                           struct htp_spad *         dst_spad,
-                                           uint32_t                  nth,
-                                           uint32_t                  ith,
-                                           uint32_t                  src0_nrows_per_thread,
-                                           dma_queue *               dma_queue) {
+static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_act_context * actx = (struct htp_act_context *) data;
+    const struct htp_tensor * src0 = &actx->octx->src0;
+    const struct htp_tensor * src1 = &actx->octx->src1;
+    const struct htp_tensor * dst  = &actx->octx->dst;
     htp_act_preamble3;
 
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    size_t src0_row_size = nb01;
-    size_t src1_row_size = nb11;
-    size_t dst_row_size  = nb1;
+    size_t src0_row_size = actx->src0_row_size;
+    size_t src1_row_size = actx->src1_row_size;
+    size_t dst_row_size  = actx->dst_row_size;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    const uint32_t src0_nrows = actx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -226,45 +230,36 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
         return;
     }
 
-    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
-    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
-    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
-
-    const bool src1_valid = src1->ne[0];
-    const int  nc         = (src1_valid) ? ne00 : ne00 / 2;
-    if (!src1_valid) {
-        const int32_t swapped = op_params[1];
-        data_src1             = data_src0;
-        src1_row_size         = src0_row_size;
+    const uint8_t * restrict data_src0 = actx->data_src0;
+    const uint8_t * restrict data_src1 = actx->data_src1;
+    uint8_t * restrict data_dst        = actx->data_dst;
 
-        const size_t nc_in_bytes = nc * SIZEOF_FP32;
-        data_src0 += swapped ? nc_in_bytes : 0;
-        data_src1 += swapped ? 0 : nc_in_bytes;
-    }
+    const int nc = actx->nc;
 
-    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
-    const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
-    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
+    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+    const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
+    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;
 
-    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
-    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
-    uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_spad->size_per_thread);
+    uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+    uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
+    uint8_t * restrict dst_spad_data  = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
 
-    // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
-    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
-    size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
-    size_t dst_spad_half_size  = dst_spad->size_per_thread / 2;
+    size_t src0_spad_half_size = actx->src0_spad_half_size;
+    size_t src1_spad_half_size = actx->src1_spad_half_size;
+    size_t dst_spad_half_size  = actx->dst_spad_half_size;
 
-    const int BLOCK = src0_spad_half_size / src0_row_size_aligned;  // How many rows can we process in one block
+    const int BLOCK = actx->block;
     if (BLOCK == 0) {
         FARF(ERROR,
              "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
              "%zu\n",
-             src0_spad->size_per_thread, src0_row_size_aligned);
+             actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
         return;
     }
-    const float alpha = ((const float *) (op_params))[2];
-    const float limit = ((const float *) (op_params))[3];
+    const float alpha = ((const float *) (actx->octx->op_params))[2];
+    const float limit = ((const float *) (actx->octx->op_params))[3];
+
+    dma_queue * dma_queue = actx->octx->ctx->dma[ith];
 
     // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
     for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
@@ -335,26 +330,22 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
 }
 
 
-static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
-                                       struct htp_tensor *       dst,
-                                       const int32_t *           op_params,
-                                       struct htp_spad *         src0_spad,
-                                       struct htp_spad *         dst_spad,
-                                       uint32_t                  nth,
-                                       uint32_t                  ith,
-                                       uint32_t                  src0_nrows_per_thread,
-                                       dma_queue *               dma_queue) {
+static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_act_context * actx = (struct htp_act_context *) data;
+    const struct htp_tensor * src0 = &actx->octx->src0;
+    const struct htp_tensor * dst  = &actx->octx->dst;
     htp_act_preamble2;
 
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    const size_t src0_row_size = nb01;
-    const size_t dst_row_size  = nb1;
-    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
-    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
+    const size_t src0_row_size = actx->src0_row_size;
+    const size_t dst_row_size  = actx->dst_row_size;
+    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;
+    const uint32_t src0_nrows = actx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -364,25 +355,29 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
         return;
     }
 
-    const uint8_t * data_src0 = (const uint8_t *) src0->data;
-    uint8_t * data_dst        = (uint8_t *) dst->data;
+    const uint8_t * data_src0 = actx->data_src0;
+    uint8_t * data_dst        = actx->data_dst;
 
-    uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
-    uint8_t * dst_spad_data  = dst_spad->data  + (ith * dst_spad->size_per_thread);
+    // nc/ne0 matches.
+    const int ne0_val = actx->nc; // == dst->ne[0]
 
-    // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
-    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
-    size_t dst_spad_half_size  = dst_spad->size_per_thread  / 2;
+    uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_data  = actx->octx->dst_spad.data  + (ith * actx->octx->dst_spad.size_per_thread);
+
+    size_t src0_spad_half_size = actx->src0_spad_half_size;
+    size_t dst_spad_half_size  = actx->dst_spad_half_size;
 
     // In gelu = x*sigmoid(x*1.702)
-    const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+    const int BLOCK = actx->block;
 
     if (BLOCK == 0) {
         FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
-                src0_spad->size_per_thread, src0_row_size_aligned);
+                actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
         return;
     }
 
+    dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
     // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
     for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
         const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@@ -408,9 +403,9 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
             float* dst_spad_ptr        = dst_spad  + ib * (dst_row_size_aligned  / sizeof(float));
 
             // gelu = x * sigmoid(1.702 * x) // current implementation
-            hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
-            hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
-            hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
+            hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val);
+            hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
+            hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
         }
 
         dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -435,34 +430,23 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
          ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-    unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
-                               octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
 
-
-static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
-                                       struct htp_tensor *       dst,
-                                       const int32_t *           op_params,
-                                       struct htp_spad *         src0_spad,
-                                       struct htp_spad *         dst_spad,
-                                       uint32_t                  nth,
-                                       uint32_t                  ith,
-                                       uint32_t                  src0_nrows_per_thread,
-                                       dma_queue *               dma_queue) {
+static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_act_context * actx = (struct htp_act_context *) data;
+    const struct htp_tensor * src0 = &actx->octx->src0;
+    const struct htp_tensor * dst  = &actx->octx->dst;
     htp_act_preamble2;
 
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    const size_t src0_row_size = nb01;
-    const size_t dst_row_size  = nb1;
-    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
-    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
+    const size_t src0_row_size = actx->src0_row_size;
+    const size_t dst_row_size  = actx->dst_row_size;
+    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;
+    const uint32_t src0_nrows = actx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -472,24 +456,27 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
         return;
     }
 
-    const uint8_t * data_src0 = (const uint8_t *) src0->data;
-    uint8_t * data_dst        = (uint8_t *) dst->data;
+    const uint8_t * data_src0 = actx->data_src0;
+    uint8_t * data_dst        = actx->data_dst;
 
-    uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
-    uint8_t * dst_spad_data  = dst_spad->data  + (ith * dst_spad->size_per_thread);
+    const int ne0_val = actx->nc; // == dst->ne[0]
 
-    // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
-    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
-    size_t dst_spad_half_size  = dst_spad->size_per_thread  / 2;
+    uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_data  = actx->octx->dst_spad.data  + (ith * actx->octx->dst_spad.size_per_thread);
 
-    const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+    size_t src0_spad_half_size = actx->src0_spad_half_size;
+    size_t dst_spad_half_size  = actx->dst_spad_half_size;
+
+    const int BLOCK = actx->block;
 
     if (BLOCK == 0) {
         FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
-                src0_spad->size_per_thread, src0_row_size_aligned);
+                actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
         return;
     }
 
+    dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
     // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
     for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
         const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@@ -515,8 +502,8 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
             float* dst_spad_ptr        = dst_spad  + ib * (dst_row_size_aligned  / sizeof(float));
 
             // silu = x * sigmoid(x)
-            hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
-            hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
+            hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val);
+            hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
         }
 
         dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -544,27 +531,22 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
 static const float GELU_COEF_A     = 0.044715f;
 static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
 
-static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
-                                       const struct htp_tensor * src1,
-                                       struct htp_tensor *       dst,
-                                       const int32_t *           op_params,
-                                       struct htp_spad *         src0_spad,
-                                       struct htp_spad *         src1_spad,
-                                       struct htp_spad *         dst_spad,
-                                       uint32_t                  nth,
-                                       uint32_t                  ith,
-                                       uint32_t                  src0_nrows_per_thread,
-                                       dma_queue *               dma_queue) {
+static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_act_context * actx = (struct htp_act_context *) data;
+    const struct htp_tensor * src0 = &actx->octx->src0;
+    const struct htp_tensor * src1 = &actx->octx->src1;
+    const struct htp_tensor * dst  = &actx->octx->dst;
     htp_act_preamble3;
 
-    size_t src0_row_size = nb01;
-    size_t src1_row_size = nb11;
-    size_t dst_row_size  = nb1;
+    size_t src0_row_size = actx->src0_row_size;
+    size_t src1_row_size = actx->src1_row_size;
+    size_t dst_row_size  = actx->dst_row_size;
 
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    const uint32_t src0_nrows = actx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -574,43 +556,34 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
         return;
     }
 
-    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
-    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
-    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
+    const uint8_t * restrict data_src0 = actx->data_src0;
+    const uint8_t * restrict data_src1 = actx->data_src1;
+    uint8_t * restrict data_dst        = actx->data_dst;
 
-    const bool src1_valid = src1->ne[0];
-    const int  nc         = (src1_valid) ? ne00 : ne00 / 2;
-    if (!src1_valid) {
-        const int32_t swapped = op_params[1];
-        data_src1             = data_src0;
-        src1_row_size         = src0_row_size;
-
-        const size_t nc_in_bytes = nc * SIZEOF_FP32;
-        data_src0 += swapped ? nc_in_bytes : 0;
-        data_src1 += swapped ? 0 : nc_in_bytes;
-    }
+    const int nc = actx->nc;
 
-    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
-    const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
-    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
+    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+    const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
+    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;
 
-    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
-    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
-    uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_spad->size_per_thread);
+    uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+    uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
+    uint8_t * restrict dst_spad_data  = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
 
-    // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
-    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
-    size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
-    size_t dst_spad_half_size  = dst_spad->size_per_thread / 2;
+    size_t src0_spad_half_size = actx->src0_spad_half_size;
+    size_t src1_spad_half_size = actx->src1_spad_half_size;
+    size_t dst_spad_half_size  = actx->dst_spad_half_size;
 
-    const int BLOCK = src0_spad_half_size / src0_row_size_aligned;  // How many rows can we process in one block
+    const int BLOCK = actx->block;
     if (BLOCK == 0) {
         FARF(ERROR,
              "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
-             src0_spad->size_per_thread, src0_row_size_aligned);
+             actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
         return;
     }
 
+    dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
     // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
     for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
         const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@@ -678,33 +651,7 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
          (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-    unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
-                               octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
-static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-    glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
-                               &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
-static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-    glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
-                                   &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
-static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-    glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
-                               &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
 static int execute_op_activations_f32(struct htp_ops_context * octx) {
-    int err = HTP_STATUS_OK;
-
     const struct htp_tensor * src0 = &octx->src0;
     const struct htp_tensor * src1 = &octx->src1;
     struct htp_tensor *       dst  = &octx->dst;
@@ -719,26 +666,26 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
 
     switch (octx->op) {
         case HTP_OP_UNARY_SILU:
-            act_op_func = unary_silu_f32;
+            act_op_func = (worker_callback_t)unary_silu_f32_per_thread;
             op_type     = "silu-f32";
             break;
 
         case HTP_OP_GLU_SWIGLU:
-            act_op_func = glu_swiglu_f32;
+            act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread;
             op_type     = "swiglu-f32";
             break;
 
         case HTP_OP_GLU_SWIGLU_OAI:
-            act_op_func = glu_swiglu_oai_f32;
+            act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread;
             op_type     = "swiglu-oai-f32";
             break;
         case HTP_OP_UNARY_GELU:
-            act_op_func = unary_gelu_f32;
+            act_op_func = (worker_callback_t)unary_gelu_f32_per_thread;
             op_type     = "gelu-f32";
             break;
 
         case HTP_OP_GLU_GEGLU:
-            act_op_func = glu_geglu_f32;
+            act_op_func = (worker_callback_t)glu_geglu_f32_per_thread;
             op_type     = "geglu-f32";
             break;
         default:
@@ -797,13 +744,58 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
              octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
     }
 
-    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        uint32_t n_jobs = MIN(n_threads, src0_nrows);
-        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
+    if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        return HTP_STATUS_OK;
     }
 
-    return err;
+    uint32_t n_jobs = MIN(n_threads, src0_nrows);
+
+    // Prepare context
+    struct htp_act_context actx;
+    actx.octx = octx;
+
+    actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+
+    actx.src0_row_size = src0_row_size;
+    actx.src1_row_size = src1_row_size;
+    actx.dst_row_size  = dst_row_size;
+
+    actx.src0_row_size_aligned = src0_row_size_aligned;
+    actx.src1_row_size_aligned = src1_row_size_aligned;
+    actx.dst_row_size_aligned  = dst_row_size_aligned;
+
+    actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2;
+    actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2;
+    actx.dst_spad_half_size  = octx->dst_spad.size_per_thread / 2;
+
+    actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned;
+    actx.src0_nrows = src0_nrows;
+
+    actx.nc = dst->ne[0];
+
+    // Pointers and GLU logic
+    const uint8_t * data_src0 = (const uint8_t *) src0->data;
+    const uint8_t * data_src1 = (const uint8_t *) src1->data;
+
+    if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
+         const int32_t swapped = octx->op_params[1];
+         data_src1 = data_src0;
+         actx.src1_row_size = actx.src0_row_size;
+
+         size_t nc_in_bytes = actx.nc * SIZEOF_FP32;
+         if (swapped) {
+             data_src0 += nc_in_bytes;
+         } else {
+             data_src1 += nc_in_bytes;
+         }
+    }
+
+    actx.data_src0 = data_src0;
+    actx.data_src1 = data_src1;
+    actx.data_dst  = (uint8_t *) dst->data;
+
+    worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs);
+    return HTP_STATUS_OK;
 }
 
 int op_activations(struct htp_ops_context * octx) {
index a657cd2dcf278dc923a67201b3ea6a4583c7594d..bf24bbda70ae2dfb14b8203b84fa1427d5909ba2 100644 (file)
 #include "htp-ops.h"
 #include "hvx-utils.h"
 
+struct get_rows_context {
+    struct htp_ops_context * octx;
+    uint32_t src1_nrows_per_thread;
+    struct fastdiv_values get_rows_div_ne10;
+    struct fastdiv_values get_rows_div_ne10_ne11;
+};
+
 #define get_rows_preamble \
     const uint32_t ne00 = octx->src0.ne[0]; \
     const uint32_t ne01 = octx->src0.ne[1]; \
                                             \
     const uint32_t nr = ne10 * ne11 * ne12;
 
-static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
+    struct get_rows_context * grctx = (struct get_rows_context *)data;
+    struct htp_ops_context * octx = grctx->octx;
     get_rows_preamble;
 
     // parallelize by src1 elements (which correspond to dst rows)
-    const uint32_t dr  = octx->src1_nrows_per_thread;
+    const uint32_t dr  = grctx->src1_nrows_per_thread;
     const uint32_t ir0 = dr * ith;
     const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
 
     const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
 
     for (uint32_t i = ir0; i < ir1; ++i) {
-        const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
+        const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);
         const uint32_t rem = i - i12 * ne11 * ne10;
-        const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
+        const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);
         const uint32_t i10 = rem - i11 * ne10;
 
         const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@@ -68,12 +77,6 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
         const uintptr_t dst_ptr  = octx->dst.data  + i10*nb1  + i11*nb2  + i12*nb3;
         hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
     }
-
-    return HTP_STATUS_OK;
-}
-
-static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
-    get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
 }
 
 int op_get_rows(struct htp_ops_context * octx) {
@@ -95,12 +98,14 @@ int op_get_rows(struct htp_ops_context * octx) {
         return HTP_STATUS_OK;
     }
 
-    octx->get_rows_div_ne10      = init_fastdiv_values(octx->src1.ne[0]);
-    octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
+    struct get_rows_context grctx;
+    grctx.octx = octx;
+    grctx.get_rows_div_ne10      = init_fastdiv_values(octx->src1.ne[0]);
+    grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
 
     const uint32_t n_jobs = MIN(nr, octx->n_threads);
-    octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+    grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
 
-    worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
+    worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs);
     return HTP_STATUS_OK;
 }
index d1ddb0ecbf04ce5d05f27e98d79286ce61381d5e..350ab9d966f9a382818bf26102876d38c9729044 100644 (file)
@@ -102,7 +102,7 @@ static inline bool dma_queue_push(dma_queue * q,
     dmlink(q->tail, desc);
     q->tail = desc;
 
-    // FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
+    // FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src);
     q->push_idx = (q->push_idx + 1) & q->idx_mask;
     return true;
 }
@@ -144,11 +144,37 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
 
     dptr = q->dptr[q->pop_idx];
 
-    // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
+    // FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
     q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
     return dptr;
 }
 
+static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) {
+    dma_ptr dptr  = { NULL };
+
+    if (q->push_idx == q->pop_idx) {
+        return dptr;
+    }
+
+    dptr = q->dptr[q->pop_idx];
+
+    // FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
+    q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
+    return dptr;
+}
+
+static inline bool dma_queue_empty(dma_queue * q) {
+    return q->push_idx == q->pop_idx;
+}
+
+static inline uint32_t dma_queue_depth(dma_queue * q) {
+    return (q->push_idx - q->pop_idx) & q->idx_mask;
+}
+
+static inline uint32_t dma_queue_capacity(dma_queue * q) {
+    return q->capacity;
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
index f1ad24dbfaa677f4846775f4a1ba2514283ab835..127ab1d66598592aac1765f1ba4ae526be4d84fa 100644 (file)
@@ -44,32 +44,6 @@ struct htp_ops_context {
     uint32_t src0_nrows_per_thread;
     uint32_t src1_nrows_per_thread;
 
-    struct fastdiv_values src0_div1;  // fastdiv values for ne1
-    struct fastdiv_values src0_div2;  // fastdiv values for ne2
-    struct fastdiv_values src0_div3;  // fastdiv values for ne3
-    struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
-
-    struct fastdiv_values src1_div1;  // fastdiv values for ne1
-    struct fastdiv_values src1_div2;  // fastdiv values for ne2
-    struct fastdiv_values src1_div3;  // fastdiv values for ne3
-    struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
-
-    struct fastdiv_values src3_div1;  // fastdiv values for ne1
-    struct fastdiv_values src3_div2;  // fastdiv values for ne2
-    struct fastdiv_values src3_div3;  // fastdiv values for ne3
-    struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
-
-    struct fastdiv_values broadcast_rk2;
-    struct fastdiv_values broadcast_rk3;
-    struct fastdiv_values broadcast_rv2;
-    struct fastdiv_values broadcast_rv3;
-
-    struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
-    struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
-
-    struct fastdiv_values get_rows_div_ne10;      // fastdiv values for ne10
-    struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
-
     uint32_t flags;
 };
 
index c360abe8dae2d308e4c4d85d0692e9574ecb6dba..6f6f51f01f58d6ce5bc9b5bdb5360d08966cfead 100644 (file)
@@ -49,62 +49,6 @@ struct htp_matmul_context {
     struct fastdiv_values mm_div_r3;
 };
 
-// vdelta control to replicate first 4x fp32 values across lanes
-static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = {
-    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
-    0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
-    0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
-    0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
-    0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
-    0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
-    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
-};
-
-// vdelta control to replicate and interleave first 8x fp32 values across lanes
-static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = {
-    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
-    0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
-    0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
-    0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
-    0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
-    0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
-    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
-};
-
-// vdelta control to replicate first fp32 value across all elements
-static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = {
-    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
-    0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
-    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
-    0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
-    0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
-    0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
-    0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
-};
-
-// vdelta control to replicate first fp16 value across all elements
-static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = {
-    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
-    0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
-    0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
-    0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
-    0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
-    0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
-    0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-};
-
-// vdelta control to replicate first fp16 value across all elements
-static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = {
-    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-};
-
 // vdelta control to expand first 32 e8m0 values into 32 uint32 elements
 static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
     0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
@@ -2067,10 +2011,10 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
     HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);  // 32 elements
 
     // Convert to QF32
-    HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
-    HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
-    HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
-    HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
+    HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes
+    HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes
+    HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes
+    HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes
 
     // Combine and convert to fp16
     HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
@@ -2080,11 +2024,6 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
     HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
     HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
 
-    // Replicate first fp16 scale across all lanes
-    HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16;
-    vmax01_hf         = Q6_V_vdelta_VV(vmax01_hf, ctrl);
-    vmax23_hf         = Q6_V_vdelta_VV(vmax23_hf, ctrl);
-
     HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
     HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
     HVX_Vector vd01_hf   = Q6_Vhf_equals_Vqf16(vd01_qf16);
@@ -2130,13 +2069,8 @@ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restric
     HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
 
     // Compute max and scale
-    HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
-    HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf));
-
-    // Replicate first fp16 scale across all lanes
-    HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
-    vmax01_hf         = Q6_V_vdelta_VV(vmax01_hf, ctrl);
-    vmax23_hf         = Q6_V_vdelta_VV(vmax23_hf, ctrl);
+    HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes
+    HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes
 
     HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
     HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
@@ -2179,11 +2113,7 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric
 
     // Compute max and scale
     HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
-    vmax_hf            = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf);
-
-    // Replicate first fp16 scale across all lanes
-    HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
-    vmax_hf         = Q6_V_vdelta_VV(vmax_hf, ctrl);
+    vmax_hf            = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes
 
     HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
     HVX_Vector vd_hf   = Q6_Vhf_equals_Vqf16(vd_qf16);
index 943ca5c952e38bca2a19b37323e9d28d20b25425..aa6a6c9008d1ef45d3d3069150ff3fe5f74caefc 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "hex-dma.h"
 #include "hvx-utils.h"
+#include "hex-fastdiv.h"
 
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
@@ -21,6 +22,9 @@
 #define HTP_ROPE_TYPE_NORMAL 0
 #define HTP_ROPE_TYPE_NEOX   2
 
+#define HTP_ROPE_SPAD_NROWS  16
+#define HTP_ROPE_SPAD_BLOCK  (HTP_ROPE_SPAD_NROWS/2)
+
 #define htp_rope_preamble              \
     const uint32_t ne00 = src0->ne[0]; \
     const uint32_t ne01 = src0->ne[1]; \
@@ -42,7 +46,7 @@
     const uint32_t nb2 = dst->nb[2];   \
     const uint32_t nb3 = dst->nb[3];
 
-struct rope_th_ctx {
+struct htp_rope_context {
     int32_t n_dims;
     int32_t mode;
     int32_t n_ctx_orig;
@@ -57,7 +61,19 @@ struct rope_th_ctx {
     float theta_scale;
     float corr_dims[2];
 
+    uint32_t src0_nrows_per_thread;
+    size_t spad_stride;
+
     struct htp_ops_context * octx;
+
+    size_t src0_row_size;
+    size_t dst_row_size;
+    size_t src0_row_size_aligned;
+    size_t dst_row_size_aligned;
+    size_t theta_cache_offset;
+    uint32_t src0_nrows;
+
+    uint64_t t_start;
 };
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -117,64 +133,23 @@ static void rope_corr_dims(int     n_dims,
     dims[1]     = MIN(n_dims - 1, end);
 }
 
-static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
-    memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
-
-    const int32_t * op_params = &octx->op_params[0];
-
-    rope_ctx->n_dims     = ((const int32_t *) op_params)[1];
-    rope_ctx->mode       = ((const int32_t *) op_params)[2];
-    rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
-
-    memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
-    memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
-    memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
-    memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
-    memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
-    memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
-    memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
-
-    rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
-
-    rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
-                   rope_ctx->beta_slow, rope_ctx->corr_dims);
-
-    rope_ctx->octx = octx;
-    FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
-         rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
-}
+static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
+    const HVX_Vector * restrict vsrc   = (const HVX_Vector *) src0;
+    const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
+    HVX_Vector       * restrict vdst   = (HVX_Vector *) dst;
 
-static void hvx_calc_rope_neox_f32(const float * restrict src0,
-                                   float * restrict dst,
-                                   const int num_elems,
-                                   const float * restrict theta_cache) {
-    // for (int i = 0; i < num_elems; i += 2) {
-    //const float cos_theta = theta_cache[i + 0];
-    //const float sin_theta = theta_cache[i + 1];
+    uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2
 
-    //const float x0 = src[0];
-    //const float x1 = src[num_elems/2];
+    uint32_t he = ne / 2;         // half_dims offset in elements
+    uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors
 
-    //dst[0] = x0*cos_theta - x1*sin_theta;
-    //dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
+    #pragma unroll(2)
+    for (uint32_t i = 0; i < nvec; i += 2) {
+        HVX_Vector v0 = vsrc[i/2+0];
+        HVX_Vector v1 = vsrc[i/2+hv];
 
-    //src += 1;
-    //dst += 1;
-    // }
-
-    const uint8_t * restrict src0_curr  = (const uint8_t *) src0;
-    const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
-    uint8_t * restrict dst_curr         = (uint8_t *) dst;
-
-    int step_of_1 = num_elems >> 6;  // 6 because we process two vectors at once
-    int half_size = (sizeof(float) * (num_elems / 2));
-
-    for (int i = 0; i < step_of_1; i++) {
-        HVX_Vector v0 = *(HVX_Vector *) src0_curr;
-        HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
-
-        HVX_Vector v2 = *(HVX_Vector *) theta_curr;
-        HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
+        HVX_Vector v2 = vtheta[i+0];
+        HVX_Vector v3 = vtheta[i+1];
 
         HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);  // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
 
@@ -186,45 +161,34 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
         HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
         HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
 
-        *(HVX_Vector *) dst_curr               = Q6_Vsf_equals_Vqf32(v4);
-        *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
+        vdst[i/2+0]  = Q6_Vsf_equals_Vqf32(v4);
+        vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5);
+    }
 
-        src0_curr += VLEN;
-        theta_curr += 2 * VLEN;
-        dst_curr += VLEN;
+    for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
+        const float cos_theta = theta_cache[i+0];
+        const float sin_theta = theta_cache[i+1];
+        float x0 = src0[i/2];
+        float x1 = src0[i/2 + he];
+        dst[i/2]      = x0 * cos_theta - x1 * sin_theta;
+        dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta;
     }
 }
 
-static void hvx_calc_rope_f32(const float * restrict src0,
-                              float * restrict dst,
-                              const int num_elems,
-                              const float * restrict theta_cache) {
-    // for (int i = 0; i < num_elems; i += 2) {
-    //const float cos_theta = theta_cache[i + 0];
-    //const float sin_theta = theta_cache[i + 1];
-
-    //const float x0 = src[0];
-    //const float x1 = src[1];
+static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
+    const HVX_Vector * restrict vsrc   = (const HVX_Vector *) src0;
+    const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
+    HVX_Vector       * restrict vdst   = (HVX_Vector *) dst;
 
-    //dst[0] = x0*cos_theta - x1*sin_theta;
-    //dst[1] = x0*sin_theta + x1*cos_theta;
+    uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two
 
-    //src += 2;
-    //dst += 2;
-    // }
+    #pragma unroll(2)
+    for (uint32_t i = 0; i < nvec; i+=2) {
+        HVX_Vector v0 = vsrc[i+0];
+        HVX_Vector v1 = vsrc[i+1];
 
-    const uint8_t * restrict src0_curr  = (const uint8_t *) src0;
-    const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
-    uint8_t * restrict dst_curr         = (uint8_t *) dst;
-
-    int step_of_1 = num_elems >> 6;  // 6 because we process two vectors at once
-
-    for (int i = 0; i < step_of_1; i++) {
-        HVX_Vector v0 = *(HVX_Vector *) src0_curr;
-        HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
-
-        HVX_Vector v2 = *(HVX_Vector *) theta_curr;
-        HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
+        HVX_Vector v2 = vtheta[i+0];
+        HVX_Vector v3 = vtheta[i+1];
 
         HVX_VectorPair vx0_x1   = Q6_W_vdeal_VVR(v1, v0, -4);  // vx0_x1[0] = x0, vx0_x1[1] = x1
         HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);  // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
@@ -239,151 +203,182 @@ static void hvx_calc_rope_f32(const float * restrict src0,
 
         HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
 
-        *(HVX_Vector *) dst_curr          = Q6_V_lo_W(vstore);
-        *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
+        vdst[i+0] = Q6_V_lo_W(vstore);
+        vdst[i+1] = Q6_V_hi_W(vstore);
+    }
+
+    for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
+        const float cos_theta = theta_cache[i+0];
+        const float sin_theta = theta_cache[i+1];
+        float x0 = src0[i+0];
+        float x1 = src0[i+1];
+        dst[i+0] = x0 * cos_theta - x1 * sin_theta;
+        dst[i+1] = x0 * sin_theta + x1 * cos_theta;
+    }
+}
+
+static void inline rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
+                   uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
+    #pragma unroll(4)
+    for (uint32_t i = 0; i < nr; i++) {
+        float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
+        float * s = (float *) (src + i * rctx->src0_row_size_aligned);
+
+        hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache);
+
+        // fill the remain channels with data from src tensor
+        if (rctx->n_dims < ne0) {
+            hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
+        }
+    }
+}
+
+static void inline rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
+                   uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
+    #pragma unroll(4)
+    for (uint32_t i = 0; i < nr; i++) {
+        float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
+        float * s = (float *) (src + i * rctx->src0_row_size_aligned);
 
-        src0_curr += 2 * VLEN;
-        theta_curr += 2 * VLEN;
-        dst_curr += 2 * VLEN;
+        hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache);
+
+        // fill the remain channels with data from src tensor
+        if (rctx->n_dims < ne0) {
+            hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
+        }
     }
 }
 
-static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
-                         const uint32_t       ir0,
-                         const uint32_t       ir1,
-                         int                  nth,
-                         int                  ith,
-                         const int            opt_path) {
-    struct htp_ops_context * octx = rope_ctx->octx;
+static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_rope_context * rctx = (struct htp_rope_context *) data;
+    struct htp_ops_context * octx = rctx->octx;
 
     const struct htp_tensor * src0 = &octx->src0;
     const struct htp_tensor * src1 = &octx->src1;
     const struct htp_tensor * src2 = &octx->src2;
     struct htp_tensor *       dst  = &octx->dst;
 
-    const int32_t mode    = rope_ctx->mode;
-    const bool    is_neox = mode & HTP_ROPE_TYPE_NEOX;
-
     htp_rope_preamble;
 
-    const int32_t * pos = (const int32_t *) src1->data;
+    const uint32_t src0_nrows = rctx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread;
 
-    float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
 
-    const float * freq_factors = NULL;
-    if (src2 != NULL) {
-        freq_factors = (const float *) src2->data;
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
     }
 
-    const uint32_t i1_end       = MIN(ir1, ne1);
-    const int32_t  half_dims    = rope_ctx->n_dims / 2;
-    const size_t   remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
-    for (uint32_t i3 = 0; i3 < ne3; i3++) {      // batch
-        for (uint32_t i2 = 0; i2 < ne2; i2++) {  // seq-len
-            const int32_t p = pos[i2];
+    uint64_t tt = HAP_perf_get_qtimer_count();
 
-            rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
-                            rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
+    const int32_t mode    = rctx->mode;
+    const bool    is_neox = mode & HTP_ROPE_TYPE_NEOX;
 
-            for (uint32_t i1 = ir0; i1 < i1_end; i1++) {  // attn-heads
-                const float * src      = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
-                float *       dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
+    // VTCM setup
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    float *   theta_cache    = (float *) (src0_spad_base);
+              src0_spad_base = src0_spad_base + rctx->theta_cache_offset;
+    uint8_t * dst_spad_base  = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
 
-                const float * src_loc      = src;
-                float *       dst_data_loc = dst_data;
+    dma_queue * dma_queue = octx->ctx->dma[ith];
+    const int32_t * pos = (const int32_t *) src1->data;
+    const float * freq_factors = src2->data ? (const float *) src2->data : NULL;
 
-                if (1 == opt_path) {
-                    if (is_neox) {
-                        hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
-                    } else {
-                        hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
-                    }
+    uint32_t ir = 0;
+    uint32_t prev_i2 = (uint32_t) -1;
 
-                    src_loc += rope_ctx->n_dims;
-                    dst_data_loc += rope_ctx->n_dims;
-                } else {
-                    for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
-                        const float cos_theta = wp0[i0 + 0];
-                        const float sin_theta = wp0[i0 + 1];
+    for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
+        for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
+            for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads
+                if (ir < src0_start_row) { ir++; i1++; continue; }
+                if (ir >= src0_end_row) goto done;
 
-                        if (is_neox) {
-                            const float x0 = src_loc[0];
-                            const float x1 = src_loc[half_dims];
+                // Rows in this block
+                const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1);
 
-                            dst_data_loc[0]         = x0 * cos_theta - x1 * sin_theta;
-                            dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
+                // Depth before prefetch
+                uint32_t dma_depth = dma_queue_depth(dma_queue);
 
-                            src_loc += 1;
-                            dst_data_loc += 1;
-                        } else {
-                            const float x0 = src_loc[0];
-                            const float x1 = src_loc[1];
+                // FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth,
+                //             (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
 
-                            dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
-                            dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
+                // Prefetch loop
+                for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) {
+                    pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK);
 
-                            src_loc += 2;
-                            dst_data_loc += 2;
-                        }
-                    }
+                    uint32_t pi1 = i1 + pr;
+                    uint32_t pir = ir + pr;
+
+                    // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
+                    dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0);
 
-                    src_loc += (is_neox ? half_dims : 0);
-                    dst_data_loc += (is_neox ? half_dims : 0);
+                    const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
+                          uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned;
+                    dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
+                        rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
+
+                    // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
                 }
 
-                // TODO: use simd to speed up the remaining elements copy
-                memcpy(dst_data_loc, src_loc, remain_bytes);
-            }
-        }
-    }
-}
+                // Update theta cache
+                if (i2 != prev_i2) {
+                    prev_i2 = i2;
 
-static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
-    struct htp_ops_context * octx = rope_ctx->octx;
+                    const int32_t p = pos[i2];
+                    rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);
 
-    const struct htp_tensor * src0 = &octx->src0;
-    const struct htp_tensor * src1 = &octx->src1;
-    struct htp_tensor *       dst  = &octx->dst;
+                    // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
+                    //         (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
+                }
 
-    htp_rope_preamble;
+                // Skip DMA transactions from prev block (if any)
+                // No need to wait for these since the DMA is setup for in-order processing
+                for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
 
-    const uint32_t src0_nrows            = ne01 * ne02 * ne03;  // src0 rows
-    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+                // Compute loop
+                for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) {
+                    // Number of rows to compute
+                    cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK);
 
-    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
-    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+                    uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src;
+                    uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst;
 
-    // no work for this thread
-    if (src0_start_row >= src0_end_row) {
-        return;
-    }
+                    // FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr,
+                    //         (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
 
-    uint64_t t1, t2;
-    t1 = HAP_perf_get_qtimer_count();
+                    if (is_neox) {
+                        rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
+                    } else {
+                        rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
+                    }
 
-    int is_aligned = 1;
-    int opt_path   = 0;
-    if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
-        (0 == hex_is_aligned((void *) dst->data, VLEN))) {
-        FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
-        is_aligned = 0;
-    }
-    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
-        opt_path = 1;
-    }
+                    uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1;
+                    dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr);
 
-    rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
+                    // Prefetch more rows (if any)
+                    if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) {
+                        uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK);
+                        uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS;
+                        uint32_t pir = ir + HTP_ROPE_SPAD_NROWS;
 
-    t2 = HAP_perf_get_qtimer_count();
+                        const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
+                        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
+                            rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
 
-    FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
-         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
-}
+                        // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
+                    }
+                }
+            }
+        }
+    }
 
-static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
-    struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
+done:
+    dma_queue_flush(dma_queue);
+    tt = HAP_perf_get_qtimer_count() - tt;
 
-    rope_job_f32_per_thread(rope_ctx, n, i);
+    FARF(HIGH, "rope-f32: %d/%d: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt));
 }
 
 static int execute_op_rope_f32(struct htp_ops_context * octx) {
@@ -394,17 +389,10 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
     const struct htp_tensor * src2 = &octx->src2;
     struct htp_tensor *       dst  = &octx->dst;
 
-    worker_callback_t op_func;
-    const char *      op_type = NULL;
-
-    struct rope_th_ctx rope_ctx;
+    const char * op_type = "rope-f32";
 
     switch (octx->op) {
         case HTP_OP_ROPE:
-            op_func = rope_job_dispatcher_f32;
-            op_type = "rope-f32";
-
-            init_rope_ctx(&rope_ctx, octx);
             break;
 
         default:
@@ -415,49 +403,79 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
     const uint32_t n_threads = octx->n_threads;
 
     const size_t src0_row_size = src0->nb[1];
-    const size_t src1_row_size = src0_row_size;
     const size_t dst_row_size  = dst->nb[1];
 
-    // VTCM scratchpads for all tensors
-    // N rows per thread, padded to HVX vector size
-    octx->dst_spad.size  = hex_round_up(dst_row_size, 128) * n_threads;
-    octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
-    octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
-
-    size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
-
-    if (src2->ne[0]) {
-        FARF(HIGH,
-             "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
-             "dst-spad-size %u\n",
-             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
-             src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
-             dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
-    } else {
-        FARF(HIGH,
-             "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
-             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
-             src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
-             octx->dst_spad.size);
-    }
+    // Aligned row sizes for VTCM
+    const size_t src0_row_size_aligned    = hex_round_up(src0_row_size, VLEN);
+    const size_t dst_row_size_aligned     = hex_round_up(dst_row_size, VLEN);
+    const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128);
+
+    // Calculate spad sizes per thread
+    size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned;
+    size_t dst_spad_per_thread  = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned;
+    size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread;
 
-    // Make sure the reserved vtcm size is sufficient
-    if (octx->ctx->vtcm_size < spad_size) {
-        FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
-             spad_size);
+    // Check if we fit in VTCM
+    size_t total_vtcm_needed = spad_per_thread * n_threads;
+    if (octx->ctx->vtcm_size < total_vtcm_needed) {
+        FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed);
         return HTP_STATUS_VTCM_TOO_SMALL;
     }
 
+    // Assign sizes
+    octx->src0_spad.size_per_thread = src0_spad_per_thread;
+    octx->dst_spad.size_per_thread  = dst_spad_per_thread;
+    octx->src0_spad.size = n_threads * src0_spad_per_thread;
+    octx->dst_spad.size  = n_threads * dst_spad_per_thread;
+    octx->src1_spad.size = 0;
+
+    // Assign pointers
     octx->src0_spad.data = octx->ctx->vtcm_base;
-    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
-    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
+    octx->src1_spad.data = NULL;
+    octx->dst_spad.data  = octx->src0_spad.data + octx->src0_spad.size;
+
+    // Fill context
+    struct htp_rope_context rctx;
+    memset(&rctx, 0, sizeof(struct htp_rope_context));
+
+    rctx.t_start = HAP_perf_get_qtimer_count();
+
+    rctx.octx = octx;
+
+    const int32_t * op_params = &octx->op_params[0];
+    rctx.n_dims     = ((const int32_t *) op_params)[1];
+    rctx.mode       = ((const int32_t *) op_params)[2];
+    rctx.n_ctx_orig = ((const int32_t *) op_params)[4];
 
+    memcpy(&rctx.freq_base,   (int32_t *) op_params + 5,  sizeof(float));
+    memcpy(&rctx.freq_scale,  (int32_t *) op_params + 6,  sizeof(float));
+    memcpy(&rctx.ext_factor,  (int32_t *) op_params + 7,  sizeof(float));
+    memcpy(&rctx.attn_factor, (int32_t *) op_params + 8,  sizeof(float));
+    memcpy(&rctx.beta_fast,   (int32_t *) op_params + 9,  sizeof(float));
+    memcpy(&rctx.beta_slow,   (int32_t *) op_params + 10, sizeof(float));
+    memcpy(&rctx.sections,    (int32_t *) op_params + 11, sizeof(int) * 4);
+
+    rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims);
+
+    rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims);
+
+    rctx.src0_row_size = src0_row_size;
+    rctx.dst_row_size  = dst_row_size;
+    rctx.src0_row_size_aligned = src0_row_size_aligned;
+    rctx.dst_row_size_aligned  = dst_row_size_aligned;
+    rctx.theta_cache_offset    = theta_cache_size_aligned;
+
+    uint32_t ne0 = dst->ne[0];
     uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    rctx.src0_nrows = src0_nrows;
+
+    FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
+         rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
 
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        uint32_t n_jobs             = MIN(n_threads, src0_nrows);
-        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
+        uint32_t n_jobs = MIN(n_threads, src0_nrows);
+        rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs);
     }
 
     return err;
index 904484da9de88e86022a6e9e3742b484c06b6f56..2fd6c907724407f726363599c21de58019034412 100644 (file)
                                             \
     const uint32_t nr  = ne01;
 
-static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+struct htp_set_rows_context {
+    struct htp_ops_context * octx;
+    struct fastdiv_values div_ne12;
+    struct fastdiv_values div_ne11;
+    uint32_t src0_nrows_per_thread;
+};
+
+static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
+    struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
+    struct htp_ops_context * octx = srctx->octx;
+
     set_rows_preamble;
 
     // parallelize by rows of src0
-    const uint32_t dr  = octx->src0_nrows_per_thread;
+    const uint32_t dr  = srctx->src0_nrows_per_thread;
     const uint32_t ir0 = dr * ith;
     const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
 
@@ -56,8 +66,8 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
     for (uint32_t i03 = 0; i03 < ne03; ++i03) {
         for (uint32_t i02 = 0; i02 < ne02; ++i02) {
             for (uint32_t i = ir0; i < ir1; ++i) {
-                const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
-                const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+                const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
+                const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
                 const uint32_t i10 = i;
 
                 const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@@ -76,15 +86,16 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
             }
         }
     }
-
-    return HTP_STATUS_OK;
 }
 
-static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) {
+    struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
+    struct htp_ops_context * octx = srctx->octx;
+
     set_rows_preamble;
 
     // parallelize by rows of src0
-    const uint32_t dr  = octx->src0_nrows_per_thread;
+    const uint32_t dr  = srctx->src0_nrows_per_thread;
     const uint32_t ir0 = dr * ith;
     const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
 
@@ -93,8 +104,8 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
     for (uint32_t i03 = 0; i03 < ne03; ++i03) {
         for (uint32_t i02 = 0; i02 < ne02; ++i02) {
             for (uint32_t i = ir0; i < ir1; ++i) {
-                const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
-                const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+                const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
+                const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
                 const uint32_t i10 = i;
 
                 const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@@ -112,16 +123,6 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
             }
         }
     }
-
-    return HTP_STATUS_OK;
-}
-
-static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
-    set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
-}
-
-static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
-    set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
 }
 
 int op_set_rows(struct htp_ops_context * octx) {
@@ -143,18 +144,20 @@ int op_set_rows(struct htp_ops_context * octx) {
         return HTP_STATUS_OK;
     }
 
-    octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
-    octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
+    struct htp_set_rows_context srctx;
+    srctx.octx = octx;
+    srctx.div_ne12 = init_fastdiv_values(ne12);
+    srctx.div_ne11 = init_fastdiv_values(ne11);
 
     const uint32_t n_jobs = MIN(nr, octx->n_threads);
-    octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+    srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
 
     switch(octx->dst.type) {
     case HTP_TYPE_F32:
-        worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs);
         break;
     case HTP_TYPE_F16:
-        worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs);
         break;
     default:
         return HTP_STATUS_NO_SUPPORT;
index e91a16d947f56f0a4db7653388a4346970fc75b1..6e22eb6a63991727a4f3bdb18955cbb06d098273 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "hex-dma.h"
 #include "hvx-utils.h"
+#include "hex-fastdiv.h"
 
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
@@ -48,7 +49,7 @@
     const uint32_t nb2 = dst->nb[2];                       \
     const uint32_t nb3 = dst->nb[3];
 
-struct softmax_th_ctx {
+struct htp_softmax_context {
     bool     use_f16;
     bool     use_src1;
     uint32_t n_head;
@@ -59,28 +60,48 @@ struct softmax_th_ctx {
     float m0;
     float m1;
 
+    uint32_t src0_nrows_per_thread;
+    struct fastdiv_values fastdiv_ne01;
+    struct fastdiv_values fastdiv_ne02;
+    struct fastdiv_values fastdiv_ne12; // For mask broadcasting
+    struct fastdiv_values fastdiv_ne13; // For mask broadcasting
+    size_t spad_stride;
+
     struct htp_ops_context * octx;
 };
 
-static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
+static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) {
     const struct htp_tensor * src0 = &octx->src0;
     const struct htp_tensor * src1 = &octx->src1;
 
-    memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
+    memset(smctx, 0, sizeof(struct htp_softmax_context));
+
+    memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
+    memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
+
+    smctx->n_head      = src0->ne[2];
+    smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head));
+
+    smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2);
+    smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2);
 
-    memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
-    memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
+    smctx->use_src1 = (src1->ne[0] != 0);
+    smctx->use_f16  = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
 
-    softmax_ctx->n_head      = src0->ne[2];
-    softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
+    smctx->octx = octx;
 
-    softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
-    softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
+    // Initialize fastdiv values
+    const uint32_t ne01 = src0->ne[1];
+    const uint32_t ne02 = src0->ne[2];
 
-    softmax_ctx->use_src1 = (src1->ne[0] != 0);
-    softmax_ctx->use_f16  = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
+    if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01);
+    if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02);
 
-    softmax_ctx->octx = octx;
+    const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1;
+    const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1;
+
+    if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12);
+    if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13);
 }
 
 static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
@@ -139,8 +160,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
         max_vec       = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
     }
 
-    HVX_Vector v = hvx_vec_reduce_max_f32(max_vec);
-    max_vec      = hvx_vec_repl4(v);
+    max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes
 
     #pragma unroll(4)
     for (int i = 0; i < step_of_1; i++) {
@@ -154,8 +174,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
         v_pad[i] = v3;
     }
 
-    v       = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
-    sum_vec = hvx_vec_repl4(v);
+    sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes
 
     HVX_VectorPred pos_sum   = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
     HVX_Vector     v4        = hvx_vec_inverse_f32(sum_vec);
@@ -183,83 +202,9 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
     return sum;
 }
 
-static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
-    struct htp_ops_context * octx = softmax_ctx->octx;
-
-    const struct htp_tensor * src0 = &octx->src0;
-    const struct htp_tensor * src1 = &octx->src1;
-    const struct htp_tensor * dst  = &octx->dst;
-
-    htp_softmax_preamble3;
-
-    uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
-    uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
-    uint8_t * dst_spad_data  = octx->dst_spad.data + (ith * nb1);
-
-    float * wp0 = (float *) src0_spad_data;
-    float * wp1 = (float *) src1_spad_data;
-    float * wp2 = (float *) dst_spad_data;
-
-    for (uint32_t i03 = 0; i03 < ne03; i03++) {
-        for (uint32_t i02 = 0; i02 < ne02; i02++) {
-            for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
-                const uint32_t i11 = i01;
-                const uint32_t i12 = i02 % ne12;
-                const uint32_t i13 = i03 % ne13;
-
-                // ALiBi
-                const uint32_t h = i02;  // head
-
-                const float slope = (softmax_ctx->max_bias > 0.0f) ?
-                                        h < softmax_ctx->n_head_log2 ?
-                                        powf(softmax_ctx->m0, h + 1) :
-                                        powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
-                                        1.0f;
-
-                float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
-                float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
-
-                // broadcast the mask across rows
-                __fp16 * mp_f16 = (softmax_ctx->use_src1) ?
-                                      (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
-                                      NULL;
-                float *  mp_f32 = (softmax_ctx->use_src1) ?
-                                      (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
-                                      NULL;
-
-                if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
-                    hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
-                                              (const uint8_t *) mp_f32, slope);
-                } else {
-                    hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
-                    if (mp_f32) {
-                        if (softmax_ctx->use_f16) {
-                            for (int i = 0; i < ne00; ++i) {
-                                wp0[i] += slope * (float) mp_f16[i];
-                            }
-                        } else {
-                            for (int i = 0; i < ne00; ++i) {
-                                wp0[i] += slope * mp_f32[i];
-                            }
-                        }
-                    }
-                }
-
-                if (1 == opt_path) {
-                    hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
-                } else {
-                    float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
-                    float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
-                    sum       = sum > 0.0 ? (1.0 / sum) : 1;
-                    hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
-                }
-            }
-        }
-    }
-}
-
-static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
-    struct htp_ops_context * octx = softmax_ctx->octx;
+static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_softmax_context * smctx = (struct htp_softmax_context *) data;
+    struct htp_ops_context * octx = smctx->octx;
 
     const struct htp_tensor * src0 = &octx->src0;
     const struct htp_tensor * src1 = &octx->src1;
@@ -268,7 +213,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
     htp_softmax_preamble3;
 
     const uint32_t src0_nrows            = ne01 * ne02 * ne03;  // src0 rows
-    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+    const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread;
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -291,20 +236,103 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
         opt_path = 1;
     }
 
-    softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
+    uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride);
+    uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride);
+    uint8_t * dst_spad_data  = octx->dst_spad.data + (ith * smctx->spad_stride);
+
+    float * wp0 = (float *) src0_spad_data;
+    float * wp1 = (float *) src1_spad_data;
+    float * wp2 = (float *) dst_spad_data;
+
+    uint32_t prev_i2 = (uint32_t)-1;
+    float slope = 1.0f;
+
+    for (uint32_t r = src0_start_row; r < src0_end_row; ++r) {
+        uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01);
+        uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01);
+        uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02);
+        uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02);
+
+        // Map to original logic indices
+        // i01 = i1
+        // i02 = i2
+        // i03 = i3
+
+        const uint32_t i11 = i1;
+        // const uint32_t i12 = i2 % ne12;
+        // const uint32_t i13 = i3 % ne13;
+
+        uint32_t i12, i13;
+        if (ne12 == ne02) {
+             i12 = i2;
+        } else {
+             i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12);
+        }
+
+        if (ne13 == ne03) {
+             i13 = i3;
+        } else {
+             i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13);
+        }
+
+        // ALiBi
+        if (i2 != prev_i2) {
+            const uint32_t h = i2;  // head
+
+            slope = (smctx->max_bias > 0.0f) ?
+                        h < smctx->n_head_log2 ?
+                        powf(smctx->m0, h + 1) :
+                        powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) :
+                        1.0f;
+            prev_i2 = i2;
+        }
+
+        float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03);
+        float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3);
+
+        // broadcast the mask across rows
+        __fp16 * mp_f16 = (smctx->use_src1) ?
+                              (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
+                              NULL;
+        float *  mp_f32 = (smctx->use_src1) ?
+                              (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
+                              NULL;
+
+        if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) {
+            hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale,
+                                      (const uint8_t *) mp_f32, slope);
+        } else {
+            hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
+            if (mp_f32) {
+                if (smctx->use_f16) {
+                    for (int i = 0; i < ne00; ++i) {
+                        wp0[i] += slope * (float) mp_f16[i];
+                    }
+                } else {
+                    for (int i = 0; i < ne00; ++i) {
+                        wp0[i] += slope * mp_f32[i];
+                    }
+                }
+            }
+        }
+
+        if (1 == opt_path) {
+            hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
+        } else {
+            float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
+            float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
+            sum       = sum > 0.0 ? (1.0 / sum) : 1;
+            hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
+        }
+    }
 
     t2 = HAP_perf_get_qtimer_count();
 
     FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
-         softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
+         smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
          ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
-    struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
-    softmax_job_f32_per_thread(p_softmax_ctx, n, i);
-}
-
 static int execute_op_softmax_f32(struct htp_ops_context * octx) {
     int err = HTP_STATUS_OK;
 
@@ -312,17 +340,12 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
     const struct htp_tensor * src1 = &octx->src1;
     struct htp_tensor *       dst  = &octx->dst;
 
-    worker_callback_t op_func;
-    const char *      op_type = NULL;
-
-    struct softmax_th_ctx softmax_ctx;
+    struct htp_softmax_context smctx;
+    const char * op_type = "softmax-f32";
 
     switch (octx->op) {
         case HTP_OP_SOFTMAX:
-            op_func = softmax_job_dispatcher_f32;
-            op_type = "softmax-f32";
-
-            init_softmax_ctx(&softmax_ctx, octx);
+            init_softmax_ctx(&smctx, octx);
             break;
 
         default:
@@ -342,6 +365,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
     octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
     octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
 
+    // Use stride for calculating offset
+    smctx.spad_stride = hex_round_up(src0_row_size, 128);
+
     size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
 
     if (src1->ne[0]) {
@@ -371,8 +397,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
 
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
         uint32_t n_jobs             = MIN(n_threads, src0_nrows);
-        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
+        smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs);
     }
 
     return err;
index 62e45da2b3520f4a480334ab18160bed2c95ec49..04fa72182a38e8636a6a3a5d2bb8f113f4e13d05 100644 (file)
@@ -17,7 +17,6 @@
 #include "htp-msg.h"
 #include "htp-ops.h"
 
-
 #define sum_rows_preamble                       \
     struct htp_tensor *src0 =  &octx->src0;\
     struct htp_tensor *dst  = &octx->dst;  \
     const uint32_t  nb2 = dst->nb[2];      \
     const uint32_t  nb3 = dst->nb[3];      \
 
-static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
-    sum_rows_preamble;
+struct sum_rows_context {
+    const uint8_t * src_data;
+    uint8_t       * dst_data;
+    uint32_t        ne00;
+    size_t          src_stride;
+    size_t          dst_stride;
+    uint32_t        rows_per_thread;
+    uint32_t        total_rows;
+    bool            opt_path;
+};
 
-    const uint32_t src0_nrows_per_thread  = octx->src0_nrows_per_thread;
-    const size_t src0_row_size = nb01;
-    const size_t dst_row_size  = nb1;
+static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) {
+    const struct sum_rows_context * smctx = (const struct sum_rows_context *) data;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    const uint32_t rows_per_thread = smctx->rows_per_thread;
+    const uint32_t total_rows      = smctx->total_rows;
 
-    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
-    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+    const uint32_t start_row = rows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + rows_per_thread, total_rows);
 
-    // no work for this thread
-    if (src0_start_row >= src0_end_row) {
-        return HTP_STATUS_OK;
+    if (start_row >= end_row) {
+        return;
     }
 
-    int opt_path   = 0;
-    if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
-        opt_path = 1;
-    }
+    const size_t   src_stride = smctx->src_stride;
+    const size_t   dst_stride = smctx->dst_stride;
+    const uint32_t ne00       = smctx->ne00;
+    const bool     opt_path   = smctx->opt_path;
 
-    const uint8_t * restrict data_src = (const uint8_t *) src0->data;
-    uint8_t * restrict data_dst       = (uint8_t *) dst->data;
+    const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride));
+    float       * restrict dst_th = (float *)       (smctx->dst_data + (start_row * dst_stride));
 
-    const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
-    float * restrict dst_th       = (float *) (data_dst + (src0_start_row * dst_row_size));
+    // Calculate actual number of rows for this thread
+    const uint32_t n_rows = end_row - start_row;
 
-    for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
-        const float * restrict src_local = src_th + (ir * ne00);
+    for (uint32_t ir = 0; ir < n_rows; ir++) {
+        const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float)));
 
-        if (ir + 1 < src0_nrows_per_thread) {
-            hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
+        if (ir + 1 < n_rows) {
+            hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1);
         }
 
-        if (1 == opt_path) {
+        if (opt_path) {
             dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
         } else {
             dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
         }
     }
-
-    return HTP_STATUS_OK;
-}
-
-static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
-    sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
 }
 
 int op_sum_rows(struct htp_ops_context * octx) {
@@ -106,10 +106,25 @@ int op_sum_rows(struct htp_ops_context * octx) {
     const uint32_t src0_nrows = ne01 * ne02 * ne03;
 
     uint32_t n_jobs = MIN(n_threads, src0_nrows);
-    octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+    uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
 
-    worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
+    bool opt_path = false;
+    if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
+        opt_path = true;
+    }
+
+    struct sum_rows_context smctx = {
+        .src_data        = (const uint8_t *) src0->data,
+        .dst_data        = (uint8_t *) dst->data,
+        .ne00            = ne00,
+        .src_stride      = nb01,
+        .dst_stride      = nb1,
+        .rows_per_thread = rows_per_thread,
+        .total_rows      = src0_nrows,
+        .opt_path        = opt_path,
+    };
+
+    worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs);
 
     return HTP_STATUS_OK;
 }
-
index ce879bf03701d7897521994bfe4b5947a1e1de0d..98135c50ab8da31aacdc7c875d334a43cfdeee2d 100644 (file)
 #include "htp-msg.h"
 #include "htp-ops.h"
 
+struct htp_unary_context {
+    struct htp_ops_context * octx;
+
+    // Precomputed values
+    const uint8_t *           data_src0;
+    uint8_t *                 data_dst;
+
+    size_t                    src0_row_size;
+    size_t                    dst_row_size;
+
+    size_t                    src0_row_size_aligned;
+    size_t                    dst_row_size_aligned;
+
+    size_t                    src0_spad_half_size;
+    size_t                    dst_spad_half_size;
+
+    uint32_t                  block;
+    uint32_t                  src0_nrows;
+    uint32_t                  src0_nrows_per_thread;
+    uint32_t                  nc;
+};
+
 #define htp_unary_preamble            \
     const uint32_t ne00 = src->ne[0]; \
     const uint32_t ne01 = src->ne[1]; \
@@ -57,8 +79,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
         sum_v         = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
     }
 
-    HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
-    sum_v                  = hvx_vec_repl4(reduced_sum);
+    sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
 
     HVX_Vector t_v            = hvx_vec_splat_f32((float) num_elems);
     HVX_Vector denom_v        = hvx_vec_inverse_f32(t_v);
@@ -75,128 +96,95 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
     }
 }
 
-static void scale_htp_f32(const float * restrict src,
-                          float * restrict dst,
-                          uint8_t * restrict spad,
-                          const uint32_t num_rows,
-                          const uint32_t row_elems,
-                          const size_t   row_size,
-                          int32_t *      op_params,
-                          int            opt_path) {
+static void scale_f32(const float * restrict src,
+                      float * restrict dst,
+                      uint8_t * restrict spad,
+                      const uint32_t num_rows,
+                      const uint32_t row_elems,
+                      const size_t   row_size,
+                      int32_t *      op_params) {
     float scale = 0.f;
     float bias  = 0.f;
     memcpy(&scale, &op_params[0], sizeof(float));
     memcpy(&bias,  &op_params[1], sizeof(float));
 
     for (uint32_t ir = 0; ir < num_rows; ir++) {
-        const float * restrict src_local = src + (ir * row_elems);
-        float * restrict dst_local       = dst + (ir * row_elems);
-
-        if (ir + 1 < num_rows) {
-            hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
-        }
+        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
+        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);
 
-        hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
+        hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
     }
 }
 
-static void rms_norm_htp_f32(const float * restrict src,
-                             float * restrict dst,
-                             uint8_t * restrict spad,
-                             const uint32_t num_rows,
-                             const uint32_t row_elems,
-                             const size_t   row_size,
-                             int32_t *      op_params,
-                             int            opt_path) {
+static void rms_norm_f32(const float * restrict src,
+                         float * restrict dst,
+                         uint8_t * restrict spad,
+                         const uint32_t num_rows,
+                         const uint32_t row_elems,
+                         const size_t   row_size,
+                         int32_t *      op_params) {
     float epsilon = 0.f;
     memcpy(&epsilon, op_params, sizeof(float));
 
     for (uint32_t ir = 0; ir < num_rows; ir++) {
-        const float * restrict src_local = src + (ir * row_elems);
-        float * restrict dst_local       = dst + (ir * row_elems);
+        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
+        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);
 
-        if (ir + 1 < num_rows) {
-            hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
-        }
-
-        if (1 == opt_path) {
-            hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
-        } else {
-            float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
-
-            const float mean  = sum / row_elems;
-            const float scale = 1.0f / sqrtf(mean + epsilon);
-
-            hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
-        }
+        hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
     }
 }
 
-static void sqr_htp_f32(const float * restrict src,
-                          float * restrict dst,
-                          uint8_t * restrict spad,
-                          const uint32_t num_rows,
-                          const uint32_t row_elems,
-                          const size_t   row_size,
-                          int32_t *      op_params,
-                          int            opt_path) {
+static void sqr_f32(const float * restrict src,
+                    float * restrict dst,
+                    uint8_t * restrict spad,
+                    const uint32_t num_rows,
+                    const uint32_t row_elems,
+                    const size_t   row_size,
+                    int32_t *      op_params) {
 
     for (uint32_t ir = 0; ir < num_rows; ir++) {
-        const float * restrict src_local = src + (ir * row_elems);
-        float * restrict dst_local       = dst + (ir * row_elems);
-
-        if (ir + 1 < num_rows) {
-            hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
-        }
+        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
+        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);
 
-        if (1 == opt_path) {
-            hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
-        } else {
-            hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
-        }
+        hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
     }
 }
 
-static void sqrt_htp_f32(const float * restrict src,
-                          float * restrict dst,
-                          uint8_t * restrict spad,
-                          const uint32_t num_rows,
-                          const uint32_t row_elems,
-                          const size_t   row_size,
-                          int32_t *      op_params,
-                          int            opt_path) {
+static void sqrt_f32(const float * restrict src,
+                     float * restrict dst,
+                     uint8_t * restrict spad,
+                     const uint32_t num_rows,
+                     const uint32_t row_elems,
+                     const size_t   row_size,
+                     int32_t *      op_params) {
 
     for (uint32_t ir = 0; ir < num_rows; ir++) {
-        const float * restrict src_local = src + (ir * row_elems);
-        float * restrict dst_local       = dst + (ir * row_elems);
+        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
+        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);
 
-        if (ir + 1 < num_rows) {
-            hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
-        }
-
-        if (1 == opt_path) {
-            hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
-        } else {
-            hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
-        }
+        hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
     }
 }
 
-static void unary_job_f32_per_thread(const struct htp_tensor * src,
-                                     struct htp_tensor *       dst,
-                                     uint8_t *                 spad,
-                                     int                       htp_op,
-                                     int32_t *                 op_params,
-                                     uint32_t                  nth,
-                                     uint32_t                  ith,
-                                     uint32_t                  src0_nrows_per_thread) {
+static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
+    struct htp_ops_context * octx = uctx->octx;
+    const struct htp_tensor * src = &octx->src0;
+    const struct htp_tensor * dst = &octx->dst;
+
     htp_unary_preamble;
 
-    const size_t src0_row_size = nb01;
-    const size_t dst_row_size  = nb1;
+    int                       htp_op = octx->op;
+    int32_t *                 op_params = octx->op_params;
+    uint32_t                  src0_nrows_per_thread = uctx->src0_nrows_per_thread;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    const size_t src0_row_size = uctx->src0_row_size;
+    const size_t dst_row_size  = uctx->dst_row_size;
 
+    const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
+    const size_t dst_row_size_aligned  = uctx->dst_row_size_aligned;
+
+    const uint32_t src0_nrows = uctx->src0_nrows;
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
 
@@ -208,79 +196,104 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    int is_aligned = 1;
-    int opt_path   = 0;
-    if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
-        is_aligned = 0;
-    }
-    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
-        opt_path = 1;
+    const uint8_t * restrict data_src = uctx->data_src0;
+    uint8_t * restrict       data_dst = uctx->data_dst;
+
+    uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_data  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+
+    size_t src0_spad_half_size = uctx->src0_spad_half_size;
+    size_t dst_spad_half_size  = uctx->dst_spad_half_size;
+
+    const int BLOCK = uctx->block;
+    if (BLOCK == 0) {
+        FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+             octx->src0_spad.size_per_thread, src0_row_size_aligned);
+        return;
     }
 
-    const uint8_t * restrict data_src = (const uint8_t *) src->data;
-    uint8_t * restrict data_dst       = (uint8_t *) dst->data;
+    dma_queue * dma_queue = octx->ctx->dma[ith];
 
-    const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
-    float * restrict dst_th       = (float *) (data_dst + (src0_start_row * dst_row_size));
-    uint8_t * restrict spad_th    = (uint8_t *) spad + (ith * nb01);
+    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
 
-    switch (htp_op) {
-        case HTP_OP_RMS_NORM:
-            rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
-            break;
-        case HTP_OP_SCALE:
-            scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
-            break;
-        case HTP_OP_SQR:
-            sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
-            break;
-        case HTP_OP_SQRT:
-            sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
-            break;
+        // Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
+        dma_queue_push_vtcm_to_ddr(dma_queue,
+            dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+            dst_row_size, dst_row_size_aligned, 0);
 
-        default:
-            break;
+        dma_queue_push_ddr_to_vtcm(dma_queue,
+            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
+            src0_row_size_aligned, src0_row_size, block_size);
     }
 
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        float * dst_spad  = (float *) dma_queue_pop(dma_queue).src;
+        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+        // Process block in VTCM
+        switch (htp_op) {
+            case HTP_OP_RMS_NORM:
+                rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
+            case HTP_OP_SCALE:
+                scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
+            case HTP_OP_SQR:
+                sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
+            case HTP_OP_SQRT:
+                sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
+            default:
+                break;
+        }
+
+        dma_queue_push_vtcm_to_ddr(dma_queue,
+            dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
+            dst_row_size, dst_row_size_aligned, block_size);
+
+        // prefetch N+2 loop iteration if any
+        const uint32_t pref_block = (ir + BLOCK * 2);
+        if (pref_block < src0_end_row) {
+            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+            dma_queue_push_ddr_to_vtcm(dma_queue,
+                dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
+                src0_row_size_aligned, src0_row_size, pref_block_size);
+        }
+    }
+
+    dma_queue_flush(dma_queue);
+
     t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
+    FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0],
          src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
          dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-
-    unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
-                             octx->src0_nrows_per_thread);
-}
-
 static int execute_op_unary_f32(struct htp_ops_context * octx) {
     int err = HTP_STATUS_OK;
 
     const struct htp_tensor * src0 = &octx->src0;
     struct htp_tensor *       dst  = &octx->dst;
 
-    worker_callback_t unary_op_func;
-    const char *      op_type = NULL;
+    const char * op_type = NULL;
 
     switch (octx->op) {
         case HTP_OP_RMS_NORM:
-            unary_op_func = unary_job_dispatcher_f32;
-            op_type       = "rmsnorm-f32";
+            op_type = "rmsnorm-f32";
             break;
         case HTP_OP_SCALE:
-            unary_op_func = unary_job_dispatcher_f32;
-            op_type       = "scale-f32";
+            op_type = "scale-f32";
             break;
         case HTP_OP_SQR:
-            unary_op_func = unary_job_dispatcher_f32;
-            op_type       = "sqr-f32";
+            op_type = "sqr-f32";
             break;
         case HTP_OP_SQRT:
-            unary_op_func = unary_job_dispatcher_f32;
-            op_type       = "sqrt-f32";
+            op_type = "sqrt-f32";
             break;
 
         default:
@@ -294,32 +307,61 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
     const size_t src0_row_size = src0->nb[1];
     const size_t dst_row_size  = dst->nb[1];
 
-    // VTCM scratchpads for all tensors
-    octx->dst_spad.size  = hex_round_up(dst_row_size, 128) * n_threads;
-    octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
+    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
 
-    size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
+    // VTCM scratchpads for all tensors
+    // N rows per thread, padded to HVX vector size
+    // Double buffering requires 2x size per buffer
 
-    FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
-         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
-         octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+    size_t spad_size_per_row   = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+    size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row);
 
     // Make sure the reserved vtcm size is sufficient
-    if (octx->ctx->vtcm_size < spad_size) {
+    if (vtcm_row_per_thread == 0) {
         FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
-             spad_size);
+             spad_size_per_row * n_threads);
         return HTP_STATUS_VTCM_TOO_SMALL;
     }
 
+    octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2;
+    octx->dst_spad.size_per_thread  = dst_row_size_aligned * vtcm_row_per_thread * 2;
+
+    octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
+    octx->dst_spad.size  = n_threads * octx->dst_spad.size_per_thread;
+
     octx->src0_spad.data = octx->ctx->vtcm_base;
     octx->dst_spad.data  = octx->src0_spad.data + octx->src0_spad.size;
 
+    FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
         uint32_t n_jobs = MIN(n_threads, src0_nrows);
 
-        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+        struct htp_unary_context uctx = {
+            .octx                  = octx,
+            .src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs,
+            .src0_nrows            = src0_nrows,
+
+            .data_src0             = (const uint8_t *)src0->data,
+            .data_dst              = (uint8_t *)dst->data,
+
+            .src0_row_size         = src0_row_size,
+            .dst_row_size          = dst_row_size,
+
+            .src0_row_size_aligned = src0_row_size_aligned,
+            .dst_row_size_aligned  = dst_row_size_aligned,
+
+            .src0_spad_half_size   = octx->src0_spad.size_per_thread / 2,
+            .dst_spad_half_size    = octx->dst_spad.size_per_thread / 2,
+
+            .block                 = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
+            .nc                    = src0->ne[0],
+        };
 
-        worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs);
     }
 
     return err;
index d19d4e920e229d8af8209f1548bf255b7701c501..dfc051b28b5631975e5db33b69ba96c333ef5a1d 100755 (executable)
@@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \
     $verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
       ./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
          --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1           \
-         --ctx-size 8192 --batch-size 128 -fa on \
-         -ngl 99 --device $device $cli_opts $@   \
+         --ctx-size 8192 --ubatch-size 256 -fa on                  \
+         -ngl 99 --device $device $cli_opts $@                     \
 "
index da9df110a0ec03ad3b3d06ac58ab4cd9d14480bb..d53b5887399d5e88b4f0b6b99b36a3b014a9fd69 100755 (executable)
@@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \
     $verbose $experimental $sched $opmask $profile $nhvx $ndev $hb        \
       ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
          --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1                  \
-         --ctx-size 8192 --batch-size 128 -fa on \
-         -ngl 99 -no-cnv --device $device $cli_opts $@   \
+         --ctx-size 8192 --ubatch-size 256 -fa on                         \
+         -ngl 99 -no-cnv --device $device $cli_opts $@                    \
 "
index fc018e726927d797e4f1e433db4d5e586f8f3dd4..41d7cd44f8d2359997bf03f85f0b271f55102657 100755 (executable)
@@ -58,11 +58,11 @@ adb $adbserial $adbhost shell " \
   cd $basedir; ulimit -c unlimited;        \
     LD_LIBRARY_PATH=$basedir/$branch/lib   \
     ADSP_LIBRARY_PATH=$basedir/$branch/lib \
-    $verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend       \
-      ./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model   \
-         --mmproj $basedir/../gguf/$mmproj \
-         --image $basedir/../gguf/$image \
-         --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1             \
-         --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
-         -ngl 99 --device $device -v $cli_opts $@ \
+    $verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \
+      ./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model      \
+         --mmproj $basedir/../gguf/$mmproj                                   \
+         --image $basedir/../gguf/$image                                     \
+         --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1                     \
+         --ctx-size 8192 --ubatch-size 256 -fa on                            \
+         -ngl 99 --device $device -v $cli_opts $@                            \
 "
index b13161aa6319a098f34a6563f3b74bd8046a2648..40c7acc430fc20e66c244c92ad2d27e31e40688b 100644 (file)
@@ -49,5 +49,5 @@ $env:ADSP_LIBRARY_PATH="$basedir\lib"
 & "$basedir\bin\llama-completion.exe" `
     --no-mmap -no-cnv -m $basedir\..\..\gguf\$model `
     --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 `
-    --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on `
+    --ctx-size 8192 --ubatch-size 128 -fa on `
     -ngl 99 --device $device $cli_opts