]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
hexagon: add f32 ssm_conv op (llama/20122)
authorTodor Boinovski <redacted>
Fri, 6 Mar 2026 17:59:26 +0000 (09:59 -0800)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
* hexagon: add ssm_conv op

* hexagon: hvx kernel is functional

* hexagon: improvements to ssm-conv hvx kernel

* hexagon: added dma to ssm-conv hvx kernel

* hexagon: ssm-conv dynamically compute gather scratchpad

* hex-ssm-conv: add local context and fix various issues (spad indexing, etc)

---------

Co-authored-by: Max Krasnyansky <redacted>
ggml/src/ggml-hexagon/ggml-hexagon.cpp
ggml/src/ggml-hexagon/htp/CMakeLists.txt
ggml/src/ggml-hexagon/htp/htp-msg.h
ggml/src/ggml-hexagon/htp/htp-ops.h
ggml/src/ggml-hexagon/htp/hvx-utils.h
ggml/src/ggml-hexagon/htp/main.c
ggml/src/ggml-hexagon/htp/ssm-conv.c [new file with mode: 0644]

index b70da8f3b28a05cc986a3f19578f45e798afbe88..d6e9776b87831620065347f0998f6e09af2643b6 100644 (file)
@@ -2152,6 +2152,44 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
     return true;
 }
 
+static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * dst  = op;
+
+    // Only support FP32 for now
+    if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    // Check IO tensor shapes and dims
+    if (src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || dst->ne[3] != 1) {
+        return false; // src0 should be effectively 3D
+    }
+
+    const int d_conv = src1->ne[0];
+    const int d_inner = src0->ne[1];
+    const int n_t = dst->ne[1];
+    const int n_s = dst->ne[2];
+
+    if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) {
+        return false;
+    }
+    if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) {
+        return false;
+    }
+    if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) {
+        return false;
+    }
+
+    // TODO: add support for non-contiguous tensors
+    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    return true;
+}
+
 enum dspqbuf_type {
     DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
     DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
@@ -2468,6 +2506,17 @@ static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buf
     return n_bufs;
 }
 
+static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_SSM_CONV;
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
     auto sess = static_cast<ggml_hexagon_session *>(backend->context);
     return sess->name.c_str();
@@ -2606,6 +2655,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
                 break;
 
+            case GGML_OP_SSM_CONV:
+                ggml_hexagon_dispatch_op<init_ssm_conv_req>(sess, node, flags);
+                break;
+
             default:
                 GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
         }
@@ -3024,6 +3077,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_argsort(sess, op);
             break;
 
+        case GGML_OP_SSM_CONV:
+            supp = ggml_hexagon_supported_ssm_conv(sess, op);
+            break;
+
         default:
             break;
     }
index 2c23b60da3d1328875d3a516c411715394edb3fa..02d07a503d50fb405f90b25c5cf91ad84ac090a4 100644 (file)
@@ -31,6 +31,7 @@ add_library(${HTP_LIB} SHARED
     get-rows-ops.c
     cpy-ops.c
     argsort-ops.c
+    ssm-conv.c
 )
 
 target_compile_definitions(${HTP_LIB} PRIVATE
index 25403bb1126538a6c0103a9de3f671077aa45cdd..52dcc36d8f7fb0402e3100e877cbe7c8cffbe563 100644 (file)
@@ -68,6 +68,7 @@ enum htp_op {
     HTP_OP_SQR,
     HTP_OP_SQRT,
     HTP_OP_SUM_ROWS,
+    HTP_OP_SSM_CONV,
     INVALID
 };
 
index 127ab1d66598592aac1765f1ba4ae526be4d84fa..2ef20936f1bb28847c3f53e121e9319f519ea985 100644 (file)
@@ -41,9 +41,6 @@ struct htp_ops_context {
     worker_pool_context_t * wpool;      // worker pool
     uint32_t                n_threads;  // num threads
 
-    uint32_t src0_nrows_per_thread;
-    uint32_t src1_nrows_per_thread;
-
     uint32_t flags;
 };
 
@@ -61,5 +58,6 @@ int op_set_rows(struct htp_ops_context * octx);
 int op_get_rows(struct htp_ops_context * octx);
 int op_cpy(struct htp_ops_context * octx);
 int op_argsort(struct htp_ops_context * octx);
+int op_ssm_conv(struct htp_ops_context * octx);
 
 #endif /* HTP_OPS_H */
index a518ad37331dee7040ea60497bbd9f1b7caf40e4..083437987946dedcded18b07540db43bda6dd7a1 100644 (file)
 #include "hvx-div.h"
 #include "hvx-base.h"
 
+#ifndef GATHER_TYPE
+#    if defined(__hexagon__)
+#        define GATHER_TYPE(_a) (intptr_t) _a
+#    else
+#        define GATHER_TYPE(_a) (HVX_Vector *) _a
+#    endif
+#endif
+
 #endif /* HVX_UTILS_H */
index 92a1422896cb88d61a913aad7491811df4c5e3da..3f99dbb32c478e7235be71a3c9bbb90cbaba7195 100644 (file)
@@ -757,6 +757,47 @@ static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req *
     send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
 }
 
+static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+
+    // We've written to the output buffer, we'd also need to flush it
+    rsp_bufs[0].fd     = bufs[2].fd;
+    rsp_bufs[0].ptr    = bufs[2].ptr;
+    rsp_bufs[0].offset = bufs[2].offset;
+    rsp_bufs[0].size   = bufs[2].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup OP context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.src1                   = req->src1;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.src1.data = (uint32_t) bufs[1].ptr;
+    octx.dst.data  = (uint32_t) bufs[2].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_ssm_conv(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
 static void proc_activations_req(struct htp_context *     ctx,
                                  struct htp_general_req * req,
                                  struct dspqueue_buffer * bufs,
@@ -1142,6 +1183,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 proc_argsort_req(ctx, &req, bufs);
                 break;
 
+            case HTP_OP_SSM_CONV:
+                if (n_bufs != 3) {
+                    FARF(ERROR, "Bad ssm-conv-req buffer list");
+                    continue;
+                }
+                proc_ssm_conv_req(ctx, &req, bufs);
+                break;
+
             default:
                 FARF(ERROR, "Unknown Op %u", req.op);
                 break;
diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c
new file mode 100644 (file)
index 0000000..b3c1ef9
--- /dev/null
@@ -0,0 +1,339 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <HAP_ps.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <qurt_thread.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "hex-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+
+#define htp_ssm_conv_tensors_preamble                        \
+    struct htp_tensor * restrict src0    = &octx->src0;      \
+    struct htp_tensor * restrict src1    = &octx->src1;      \
+    struct htp_tensor * restrict dst     = &octx->dst;       \
+    struct htp_spad * restrict src0_spad = &octx->src0_spad; \
+    struct htp_spad * restrict src1_spad = &octx->src1_spad; \
+    struct htp_spad * restrict dst_spad  = &octx->dst_spad;  \
+                                                             \
+    const uint32_t ne00 = src0->ne[0];                       \
+    const uint32_t ne01 = src0->ne[1];                       \
+    const uint32_t ne02 = src0->ne[2];                       \
+    const uint32_t ne03 = src0->ne[3];                       \
+                                                             \
+    const uint32_t ne10 = src1->ne[0];                       \
+    const uint32_t ne11 = src1->ne[1];                       \
+    const uint32_t ne12 = src1->ne[2];                       \
+    const uint32_t ne13 = src1->ne[3];                       \
+                                                             \
+    const uint32_t ne0 = dst->ne[0];                         \
+    const uint32_t ne1 = dst->ne[1];                         \
+    const uint32_t ne2 = dst->ne[2];                         \
+    const uint32_t ne3 = dst->ne[3];                         \
+                                                             \
+    const uint32_t nb00 = src0->nb[0];                       \
+    const uint32_t nb01 = src0->nb[1];                       \
+    const uint32_t nb02 = src0->nb[2];                       \
+    const uint32_t nb03 = src0->nb[3];                       \
+                                                             \
+    const uint32_t nb10 = src1->nb[0];                       \
+    const uint32_t nb11 = src1->nb[1];                       \
+    const uint32_t nb12 = src1->nb[2];                       \
+    const uint32_t nb13 = src1->nb[3];                       \
+                                                             \
+    const uint32_t nb0 = dst->nb[0];                         \
+    const uint32_t nb1 = dst->nb[1];                         \
+    const uint32_t nb2 = dst->nb[2];                         \
+    const uint32_t nb3 = dst->nb[3];
+
+struct htp_ssm_conv_context {
+    struct htp_ops_context * octx;
+    uint32_t nrows_per_thread;
+    uint64_t t_start;
+};
+
+#define htp_ssm_conv_preamble                            \
+    struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \
+    struct htp_ops_context * octx = scctx->octx;         \
+    htp_ssm_conv_tensors_preamble;                       \
+    dma_queue * dma_queue         = octx->ctx->dma[ith];
+
+// Scalar FP32 SSM_CONV implementation
+static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
+    htp_ssm_conv_preamble;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t d_conv  = src1->ne[0];
+    const uint32_t d_inner = src0->ne[1];
+    const uint32_t n_t     = dst->ne[1];
+    const uint32_t n_s     = dst->ne[2];
+
+    const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension
+    const uint32_t src0_stride_seq   = src0->nb[2] / sizeof(float); // stride for sequence dimension
+    const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension
+    const uint32_t dst_stride_token  = dst->nb[1]  / sizeof(float); // stride for token dimension
+    const uint32_t dst_stride_seq    = dst->nb[2]  / sizeof(float); // stride for sequence dimension
+
+    const float * src0_data = (const float *) src0->data;
+    const float * src1_data = (const float *) src1->data;
+    float *       dst_data  = (float *) dst->data;
+
+    // Calculate row range for this thread
+    const uint32_t d_inner_per_thread = scctx->nrows_per_thread;
+    const uint32_t d_inner_start = d_inner_per_thread * ith;
+    const uint32_t d_inner_end   = MIN(d_inner_start + d_inner_per_thread, d_inner);
+
+    // No work for this thread
+    if (d_inner_start >= d_inner_end) {
+        return;
+    }
+
+    for (uint32_t i3 = 0; i3 < n_s; ++i3) {
+        for (uint32_t i2 = 0; i2 < n_t; ++i2) {
+            for (uint32_t i1 = d_inner_start; i1 < d_inner_end; ++i1) {
+                float sumf = 0.0f;
+
+                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
+                    const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq;
+                    const uint32_t src1_idx = i0 + i1 * src1_stride_inner;
+
+                    sumf += src0_data[src0_idx] * src1_data[src1_idx];
+                }
+
+                const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq;
+                dst_data[dst_idx] = sumf;
+            }
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
+         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], d_inner_start, d_inner_end,
+         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
+         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension
+static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) {
+    htp_ssm_conv_preamble;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+
+    const uint32_t d_conv  = src1->ne[0];
+    const uint32_t d_inner = src0->ne[1];
+    const uint32_t n_t     = dst->ne[1];
+    const uint32_t n_s     = dst->ne[2];
+
+    const float * src0_data = (const float *) src0->data;
+    const float * src1_data = (const float *) src1->data;
+    float *       dst_data  = (float *) dst->data;
+
+    // Calculate row range for this thread
+    const int dr = scctx->nrows_per_thread;
+    const uint32_t ir0 = dr * ith;
+    const uint32_t ir1 = MIN(ir0 + dr, d_inner);
+    const int      ir  = ir1 - ir0;
+
+    if (ir0 >= ir1) {
+        return;  // No work for this thread
+    }
+
+    // src0 and src1 gather offsets
+    uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 };
+    uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 };
+
+    for (uint32_t i = 0; i < VLEN_FP32; ++i) {
+        src0_offsets[i] = i * (ncs)    * sizeof(float);
+        src1_offsets[i] = i * (d_conv) * sizeof(float);
+    }
+
+    const uint32_t src0_gather_len = VLEN * ncs;
+    const uint32_t src1_gather_len = VLEN * d_conv;
+
+    // gather scratchpads
+    HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0);
+    HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN);
+
+    float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]);
+    float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]);
+
+    uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;
+    uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;
+
+    // copy src1 workload to VTCM
+    dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir);
+
+    // FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir);
+
+    for (uint32_t i3 = 0; i3 < n_s; ++i3) {
+        float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2]));
+
+        // copy src0 workload to VTCM
+        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir);
+
+        // FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir);
+
+        dma_queue_flush(dma_queue);
+
+        for (uint32_t i2 = 0; i2 < n_t; ++i2) {
+            float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2]));
+
+            const uint32_t nvec = ir / VLEN_FP32;
+            const uint32_t nloe = ir % VLEN_FP32;
+            uint32_t i1 = 0;
+
+            for (uint32_t vi1 = 0; vi1 < nvec; vi1++) {
+                HVX_Vector acc_vec = Q6_V_vsplat_R(0);
+
+                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
+                    Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
+                                     src0_gather_len, (*(const HVX_Vector *) src0_offsets));
+                    Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
+                                     src1_gather_len, (*(const HVX_Vector *) src1_offsets));
+
+                    HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
+                    acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
+                }
+
+                *(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec);
+                i1 += VLEN_FP32;
+            }
+
+            if (nloe) {
+                HVX_Vector acc_vec = Q6_V_vsplat_R(0);
+
+                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
+                    Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
+                                     src0_gather_len, (*(const HVX_Vector *) src0_offsets));
+                    Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
+                                     src1_gather_len, (*(const HVX_Vector *) src1_offsets));
+
+                    HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
+                    acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
+                }
+
+                hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec));
+            }
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
+         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
+         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
+         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+int op_ssm_conv_f32(struct htp_ops_context * octx) {
+    htp_ssm_conv_tensors_preamble;
+
+    if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) {
+        FARF(ERROR, "ssm_conv: only (F32 x F32 -> F32) OPs supported");
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    struct htp_ssm_conv_context scctx = { 0 };
+    scctx.octx = octx;
+
+    const uint32_t d_conv  = src1->ne[0];
+    const uint32_t d_inner = src0->ne[1];
+    const uint32_t n_t     = dst->ne[1];  // tokens per sequence
+    const uint32_t n_s     = dst->ne[2];  // number of sequences in the batch
+
+    const uint32_t n_threads = MIN(octx->n_threads, d_inner);
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        uint32_t use_hvx = 0;
+        if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) {
+            int is_aligned = hex_is_aligned((void *) src0->data, VLEN) &&
+                             hex_is_aligned((void *) src1->data, VLEN) &&
+                             hex_is_aligned((void *) dst->data, VLEN);
+
+            if (is_aligned) {
+                use_hvx = 1;
+            }
+        }
+
+        if (use_hvx) {
+            scctx.nrows_per_thread  = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread
+            scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even
+
+            octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256);
+            octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256);
+            octx->dst_spad.size_per_thread  = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256);
+
+            octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads;
+            octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads;
+            octx->dst_spad.size  = octx->dst_spad.size_per_thread  * n_threads;
+
+            // Compute gather scratchpad size for src0 and src1
+            const size_t gather_spad_size = n_threads * VLEN * 2;
+
+            octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size;
+            octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+            octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
+
+            FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
+                gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread,
+                octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size,
+                octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data);
+
+            const size_t total_spad_size =
+                gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+
+            if (total_spad_size > octx->ctx->vtcm_size) {
+                FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size,
+                     octx->ctx->vtcm_size);
+                use_hvx = 0;
+            }
+        }
+
+        FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n", src0->ne[0],
+             src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
+             dst->ne[1], dst->ne[2], dst->ne[3], use_hvx);
+
+        if (use_hvx) {
+            worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32_hvx, &scctx, n_threads);
+        } else {
+            worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32, &scctx, n_threads);
+        }
+    }
+
+    return HTP_STATUS_OK;
+}
+
+int op_ssm_conv(struct htp_ops_context * octx) {
+    int                 err = HTP_STATUS_OK;
+    struct htp_tensor * dst = &octx->dst;
+
+    switch (dst->type) {
+        case HTP_TYPE_F32:
+            err = op_ssm_conv_f32(octx);
+            break;
+        default:
+            err = HTP_STATUS_NO_SUPPORT;
+            break;
+    }
+
+    return err;
+}