ggml: riscv: add riscv spacemit backend (llama/15288)
author     alex-spacemit <redacted>
           Mon, 29 Sep 2025 14:50:44 +0000 (22:50 +0800)
committer  Georgi Gerganov <redacted>
           Mon, 29 Sep 2025 18:21:06 +0000 (21:21 +0300)
* ggml: add spacemit backend

Change-Id: I249bdc043485d815a9c351867137bc1e27cc2e23

* add new line at end of file

Change-Id: I889ed1c85fb45e62350ecde0c06f70450cadfbe2

* add riscv zba extension limit

Change-Id: I321eb200f859751727afe5cae13074dfce2bb0ce

* address review comments: rename file and fix formatting

Change-Id: Ia20b6ec24a36638e62e0fe07cf100916a7cce3ce

* fix code formatting after clang-format

Change-Id: I5dc33a0412da3d3f2d77075d8939185d3009eca2

* use _Float16 instead of __fp16

Change-Id: I039fb02bb95270e641bc4442204e658735859d43

* add ci for riscv64-spacemit-ime-native

Change-Id: I711c1033061df1a289ea77891b2997599dfe8279

* update debian-13-riscv64-spacemit-ime-native ci label

Change-Id: Ifb2b891e2fca57b5da604fce2ac255f27731179a

* remove license comment for spacemit ime

Change-Id: If0dc3ca30a958631ccca0a28b62e0b825f9fb0c3

* upgrade binutils for gcc ime

Change-Id: Ibf2fa74c1064408974cb5b45f044d40987e5fb45

* add spacemit ime cross jobs

Change-Id: I80d74909941d41cb9cd09e51d8baf01c985cbfc6

* remove native compile for riscv64-spacemit-ime

Change-Id: I01920afafdc73fa7424014fd648d243f8ec9e25e

* ci : add caching for spacemit ime cross toolchain

Change-Id: Ic54a192019a2fd982bbd58225ce3bbc38f4053de

* ci: fix bug in cache path and env

Change-Id: I28c42e10b6fff053bb6580926ca2353448cb042a

* Update .github/workflows/build-linux-cross.yml for cache path

Co-authored-by: Sigbjørn Skjæret <redacted>
* fix syntax error in build-linux-cross.yml

Co-authored-by: Sigbjørn Skjæret <redacted>
---------

Co-authored-by: cailinxi <redacted>
Co-authored-by: Sigbjørn Skjæret <redacted>
src/ggml-cpu/CMakeLists.txt
src/ggml-cpu/ggml-cpu.cpp
src/ggml-cpu/spacemit/ime.cpp [new file with mode: 0644]
src/ggml-cpu/spacemit/ime.h [new file with mode: 0644]
src/ggml-cpu/spacemit/ime1_kernels.cpp [new file with mode: 0644]
src/ggml-cpu/spacemit/ime_kernels.h [new file with mode: 0644]

index 369905750754fb48df1fbac3dc1d5296e1766670..50bb9cac92bca8990e8d8c05657887b694e16a54 100644 (file)
@@ -439,6 +439,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             ggml-cpu/arch/riscv/quants.c
             ggml-cpu/arch/riscv/repack.cpp
             )
+        if (GGML_CPU_RISCV64_SPACEMIT)
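+            # ${RISCV64_SPACEMIT_IME_SPEC} is expected to carry the IME generation define (ime.cpp below requires RISCV64_SPACEMIT_IME1)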
+            target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC})
+            list(APPEND GGML_CPU_SOURCES
+                ggml-cpu/spacemit/ime.cpp
+                ggml-cpu/spacemit/ime.h
+                ggml-cpu/spacemit/ime1_kernels.cpp
+                ggml-cpu/spacemit/ime_kernels.h
+            )
+        endif()
         set(MARCH_STR "rv64gc")
         if (GGML_RV_ZFH)
             string(APPEND MARCH_STR "_zfh")
index 81a314e4d68d750512107cdd0e61b26094f2e419..3191faaa4cd92c3bd4b46f2a84dd879db1d348ab 100644 (file)
 #    include "kleidiai/kleidiai.h"
 #endif
 
+#ifdef GGML_USE_CPU_RISCV64_SPACEMIT
+#    include "spacemit/ime.h"
+#endif
+
 #if defined(_WIN32)
 #    define WIN32_LEAN_AND_MEAN
 #    ifndef NOMINMAX
@@ -45,6 +49,12 @@ std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_type
         }
 #endif
 
+#ifdef GGML_USE_CPU_RISCV64_SPACEMIT
+        if (ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
+            bufts.push_back(ggml_backend_cpu_riscv64_spacemit_buffer_type());
+        }
+#endif
+
 #ifdef GGML_USE_CPU_KLEIDIAI
         if (ggml_backend_cpu_kleidiai_buffer_type()) {
             bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
diff --git a/src/ggml-cpu/spacemit/ime.cpp b/src/ggml-cpu/spacemit/ime.cpp
new file mode 100644 (file)
index 0000000..54d3dec
--- /dev/null
@@ -0,0 +1,1024 @@
+#define GGML_COMMON_IMPL_CPP
+#define GGML_COMMON_DECL_CPP
+
+#include "ime.h"
+
+#include "ggml-backend-impl.h"
+#include "ggml-common.h"
+#include "ggml-cpu.h"
+#include "ime_kernels.h"
+#include "traits.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdio>  // for GGML_ASSERT
+#include <stdexcept>
+#include <thread>
+
+// clang-format off
+#if defined(__riscv)
+
+#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic)
+#error "riscv v extension or v_intrinsic not enabled"
+#else
+#include <riscv_vector.h>
+#endif
+
+#if !defined(__riscv_zfh)
+#error "riscv zfh extension not enabled"
+#endif
+
+#if !defined(RISCV64_SPACEMIT_IME1)
+#error "RISCV64_SPACEMIT_IME1 not defined"
+#endif
+
+#else
+
+#error "riscv not enabled in this build"
+
+#endif
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#pragma GCC diagnostic ignored "-Wcast-qual"
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#endif
+
+#if defined(RISCV64_SPACEMIT_IME1)
+#define QGEMM_STRIDEN_THREAD_ALIGN 16
+#else
+#define QGEMM_STRIDEN_THREAD_ALIGN 32
+#endif
+
+// clang-format on
+
+struct qnbitgemm_spacemit_ime_args {
+    const float *     a_ptr               = nullptr;
+    size_t            lda                 = 0;
+    const std::byte * packed_quant_b_data = nullptr;
+    const float *     quant_b_scale       = nullptr;
+    const void *      quant_b_zp          = nullptr;
+    const float *     quant_b_blksum      = nullptr;
+    const float *     bias                = nullptr;
+    float *           c_ptr               = nullptr;
+    size_t            ldc                 = 0;
+};
+
+constexpr size_t div_round_up(size_t up, size_t down) {
+    return (up + down - 1) / down;
+}
+
+constexpr size_t q8_blk_size(size_t blk_len) {
+    const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t);
+    // Currently, the strictest alignment requirement of a block is for a float.
+    // Ensure contiguous blocks are suitably aligned.
+    assert(blk_size % alignof(float) == 0);
+    return blk_size;
+}
+
+namespace ggml::cpu::riscv64_spacemit {
+
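+// Heuristic: presumably only half of the logical cores are backed by an IME matrix unit,
+// so the GEMM phase below limits itself to this many threads.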
+const int num_ai_cores = std::thread::hardware_concurrency() / 2;
+
+}  // namespace ggml::cpu::riscv64_spacemit
+
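+// Drive the int8 (activations) x int4 (packed weights) GEMM over the tile
+// [m_start, m_start + m_count) x [n_start, n_start + n_count): columns are walked in chunks of 16
+// (or all at once for a single-row GEMV) and the IME kernel reports how many rows it handled per call.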
+static void sqnbitgemm_spacemit_ime_i8i4(const size_t                        blk_len,
+                                         const size_t                        gemm_k,
+                                         const qnbitgemm_spacemit_ime_args * gemm_args,
+                                         void * const                        per_gemm_ws,
+                                         const size_t                        m_start,
+                                         const size_t                        m_count,
+                                         const size_t                        n_start,
+                                         const size_t                        n_count) {
+    constexpr size_t scale_stride = sizeof(uint16_t);
+    constexpr size_t blk_bitwidth = 4;
+
+    const size_t k_blks = div_round_up(gemm_k, blk_len);
+
+    const size_t      lda         = k_blks * q8_blk_size(blk_len);
+    const size_t      ldc         = gemm_args->ldc;
+    const size_t      ldb         = k_blks * (blk_len * blk_bitwidth / 8);
+    const std::byte * quant_a_ptr = static_cast<const std::byte *>(per_gemm_ws) + m_start * lda;
+
+    const size_t      zero_point_stride   = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0;
+    const size_t      packed_b_stride     = ldb + k_blks * (scale_stride + zero_point_stride);
+    const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride;
+
+    float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start;
+
+    size_t       count_n               = 0;
+    const size_t compute_block_count_n = m_count == 1 ? n_count : 16;
+    for (size_t n = 0; n < n_count; n += count_n) {
+        count_n = std::min(n_count - n, compute_block_count_n);
+
+        const std::byte * a_row    = quant_a_ptr;
+        const std::byte * b_col    = packed_quant_b_data + n * packed_b_stride;
+        const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr;
+        float *           c_blk    = c_ptr + n;
+
+        int32_t rows_remaining = m_count;
+
+        while (rows_remaining > 0) {
+            const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4(
+                blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr,
+                scale_stride);
+
+            c_blk += rows_handled * ldc;
+            a_row += rows_handled * lda;
+
+            rows_remaining -= rows_handled;
+        }
+    }
+}
+
+template <int K> constexpr int QK_0() {
+    if constexpr (K == 4) {
+        return QK4_0;
+    }
+    if constexpr (K == 8) {
+        return QK8_0;
+    }
+    return -1;
+}
+
+template <int K, int N> struct block {
+    ggml_half d[N];                         // deltas for N qK_0 blocks
+    uint8_t   qs[(QK_0<K>() * N * K) / 8];  // quants for N qK_0 blocks
+};
+
+template <int K, int N> struct block_with_zp {
+    ggml_half d[N];                         // deltas for N qK_1 blocks
+    uint8_t   zp[N];                        // zero points for N qK_1 blocks
+    uint8_t   qs[(QK_0<K>() * N * K) / 8];  // quants for N qK_1 blocks
+};
+
+// control size
+static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding");
+static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t),
+              "wrong block_with_zp<4,16> size/padding");
+static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding");
+
+using block_q4_0x16 = block<4, 16>;
+using block_q4_1x16 = block_with_zp<4, 16>;
+using block_q8_0x16 = block<8, 16>;
+
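+// Repack 16 consecutive rows' q4_0 blocks into one block_q4_0x16: the 16 deltas come first, then the
+// low nibbles of all rows, then the high nibbles (presumably the layout expected by the IME GEMM kernel).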
+static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) {
+    block_q4_0x16 out;
+    GGML_ASSERT(QK4_0 / blck_size_interleave == 2);
+
+    for (int i = 0; i < 16; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    for (int i = 0; i < 16; i++) {
+        // elements [0, 15]: low nibbles, in.qs & 0x0F
+        for (int j = 0; j < QK4_0 / 4; j++) {
+            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
+            //dst [b0 b8] ......... [b7 b15]
+            out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4);
+        }
+    }
+
+    for (int i = 0; i < 16; i++) {
+        // elements [16, 31]: high nibbles, in.qs & 0xF0
+        for (int j = 0; j < QK4_0 / 4; j++) {
+            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
+            //dst [b16 b24] ......... [b23 b31]
+            out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0);
+        }
+    }
+
+    return out;
+}
+
+static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) {
+    block_q4_1x16 out;
+    GGML_ASSERT(QK4_1 / blck_size_interleave == 2);
+
+    for (int i = 0; i < 16; i++) {
+        float d   = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
+        float m   = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m);
+        float mid = -std::nearbyintf(m / d);
+        mid       = std::min(15.0f, std::max(0.0f, mid));
+        out.d[i]  = GGML_FP32_TO_FP16(d);
+        out.zp[i] = static_cast<uint8_t>(mid);
+    }
+
+    for (int i = 0; i < 16; i++) {
+        // elements [0, 15]: low nibbles, in.qs & 0x0F
+        for (int j = 0; j < QK4_1 / 4; j++) {
+            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
+            //dst [b0 b8] ......... [b7 b15]
+            out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4);
+        }
+    }
+
+    for (int i = 0; i < 16; i++) {
+        // elements [16, 31]: high nibbles, in.qs & 0xF0
+        for (int j = 0; j < QK4_1 / 4; j++) {
+            //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
+            //dst [b16 b24] ......... [b23 b31]
+            out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0);
+        }
+    }
+
+    return out;
+}
+
+static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor *       t,
+                                     int                        interleave_block,
+                                     const void * GGML_RESTRICT data,
+                                     size_t                     data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(interleave_block == 16);
+
+    constexpr int nrows_interleaved = 16;
+
+    block_q4_0x16 *    dst = (block_q4_0x16 *) t->data;
+    const block_q4_0 * src = (const block_q4_0 *) data;
+    block_q4_0         dst_tmp[16];
+    int                nrow    = ggml_nrows(t);
+    int                nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x16(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor *       t,
+                                     int                        interleave_block,
+                                     const void * GGML_RESTRICT data,
+                                     size_t                     data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(interleave_block == 16);
+
+    constexpr int nrows_interleaved = 16;
+
+    block_q4_1x16 *    dst = (block_q4_1x16 *) t->data;
+    const block_q4_1 * src = (const block_q4_1 *) data;
+    block_q4_1         dst_tmp[16];
+    int                nrow    = ggml_nrows(t);
+    int                nblocks = t->ne[0] / QK4_1;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static inline void get_scale_min_k4(int                           j,
+                                    const uint8_t * GGML_RESTRICT q,
+                                    uint8_t * GGML_RESTRICT       d,
+                                    uint8_t * GGML_RESTRICT       m) {
+    if (j < 4) {
+        *d = q[j] & 63;
+        *m = q[j + 4] & 63;
+    } else {
+        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
+        *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
+    }
+}
+
+static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor *       t,
+                                     int                        interleave_block,
+                                     const void * GGML_RESTRICT data,
+                                     size_t                     data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
+    GGML_ASSERT(interleave_block == 16);
+    GGML_ASSERT(QK_K / QK4_1 == 8);
+
+    constexpr int nrows_interleaved = 16;
+
+    block_q4_1x16 *    dst = (block_q4_1x16 *) t->data;
+    const block_q4_K * src = (const block_q4_K *) data;
+    block_q4_1         dst_tmp[16];
+    int                nrow    = ggml_nrows(t);
+    int                nblocks = t->ne[0] / QK_K;
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int j = 0; j < 8; j++) {
+                for (int i = 0; i < nrows_interleaved; i++) {
+                    uint8_t     sc, m;
+                    const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
+                    const float min =
+                        GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin);
+                    get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m);
+                    const float d1 = d * sc;
+                    const float m1 = min * m;
+
+                    dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1);
+                    dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1);
+                    // src -> [b0, b32] [b1, b33] ... [b31, b63]
+                    // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63]
+                    const uint8_t * q                                  = src[x + i * nblocks].qs + (j / 2) * QK4_1;
+                    if (j % 2 == 0) {
+                        for (int ii = 0; ii < 16; ii++) {
+                            dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4);
+                        }
+                    } else {
+                        for (int ii = 0; ii < 16; ii++) {
+                            dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0);
+                        }
+                    }
+                }
+                *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
+            }
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+namespace ggml::cpu::riscv64_spacemit {
+
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+int repack(struct ggml_tensor *, const void *, size_t);
+
+template <> int repack<block_q4_0, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size);
+}
+
+template <> int repack<block_q4_1, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size);
+}
+
+template <> int repack<block_q4_K, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size);
+}
+
+class tensor_traits_base : public ggml::cpu::tensor_traits {
+  public:
+    virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
+};
+
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
+    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        switch (op->op) {
+            case GGML_OP_MUL_MAT:
+                size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4;
+                size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float));
+                return true;
+            default:
+                // GGML_ABORT("fatal error");
+                break;
+        }
+        return false;
+    }
+
+    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
+        switch (op->op) {
+            case GGML_OP_MUL_MAT:
+                if (op->src[0]->type == GGML_TYPE_Q4_0 ||  //
+                    op->src[0]->type == GGML_TYPE_Q4_1 ||  //
+                    op->src[0]->type == GGML_TYPE_Q4_K) {
+                    forward_mul_mat_q4(params, op);
+                    return true;
+                }
+            default:
+                // GGML_ABORT("fatal error");
+                break;
+        }
+        return false;
+    }
+
+    void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) {
+        const ggml_tensor * src0 = op->src[0];
+        const ggml_tensor * src1 = op->src[1];
+        ggml_tensor *       dst  = op;
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        int ith = params->ith;
+        int nth = params->nth;
+
+        [[maybe_unused]] const enum ggml_type type = src0->type;
+
+        void *        w_data  = (void *) src0->data;
+        const float * feature = (const float *) src1->data;
+        float *       output  = (float *) dst->data;
+
+        const size_t                  batch_feature = ne12 * ne13;
+        [[maybe_unused]] const size_t batch_weight  = ne02 * ne03;
+        const size_t                  gemm_m        = ne11;
+        const size_t                  gemm_k        = ne10;
+        const size_t                  gemm_n        = ne01;
+
+        GGML_ASSERT(batch_weight == 1);
+
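+        // src1 is quantized on the fly into Q8 row blocks; reserve one workspace region per GEMM in the
+        // batch, padded so that each region starts on an 8-byte boundary.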
+        const size_t block_count_k           = div_round_up(gemm_k, QK4_0);
+        const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0);
+        const size_t per_gemm_workspace_stride =
+            div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t);
+        const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride;
+        const size_t desired_wsize       = gemm_workspace_size + alignof(uint64_t) - 1;
+
+        if (ith == 0 && params->wsize < desired_wsize) {
+            throw std::runtime_error("wsize less than desired_wsize");
+        }
+
+        std::vector<qnbitgemm_spacemit_ime_args> qnbitgemm_args(batch_feature);
+
+        for (size_t i = 0; i < batch_feature; i++) {
+            qnbitgemm_args[i].a_ptr               = feature + gemm_m * gemm_k * i;
+            qnbitgemm_args[i].lda                 = gemm_k;
+            qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data;
+            qnbitgemm_args[i].quant_b_scale       = nullptr;
+
+            if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) {
+                qnbitgemm_args[i].quant_b_zp = nullptr;
+            } else {
+                qnbitgemm_args[i].quant_b_zp = w_data;
+            }
+
+            qnbitgemm_args[i].bias  = nullptr;
+            qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i;
+            qnbitgemm_args[i].ldc   = gemm_n;
+        }
+
+        const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
+        void *          ws = reinterpret_cast<void *>((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1)));
+        const size_t    quant_a_stride = block_count_k * q8_blk_size(QK4_0);
+
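+        // Phase 1: each thread quantizes its share of the activation rows into the shared workspace,
+        // four rows at a time where possible and row by row for the tail, then all threads meet at the
+        // barrier below before the GEMM phase starts.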
+        {
+            constexpr size_t block_size_m           = 4;
+            size_t           per_gemm_block_count_m = div_round_up(gemm_m, block_size_m);
+            int32_t          task_count             = batch_feature * per_gemm_block_count_m;
+            int32_t          task_per_thread        = (task_count + nth - 1) / nth;
+            int32_t          start                  = ith * task_per_thread;
+            int32_t          end                    = std::min((ith + 1) * task_per_thread, task_count);
+            for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {
+                int32_t                             gemm_idx = compute_idx / block_size_m;
+                int32_t                             m_idx    = compute_idx % block_size_m * block_size_m;
+                const qnbitgemm_spacemit_ime_args & data     = qnbitgemm_args[gemm_idx];
+                int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);
+
+                if (rows_tobe_handled == block_size_m) {
+                    const float * a_row_ptr = data.a_ptr + m_idx * data.lda;
+                    std::byte *   quant_a_row_ptr =
+                        static_cast<std::byte *>(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
+                    sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
+                } else {
+                    while (rows_tobe_handled) {
+                        const float * a_row_ptr       = data.a_ptr + m_idx * data.lda;
+                        std::byte *   quant_a_row_ptr = static_cast<std::byte *>(ws) +
+                                                      gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
+                        sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
+                        rows_tobe_handled -= 1;
+                        m_idx += 1;
+                    }
+                }
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+
+        if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) {
+            return;
+        }
+        nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores });
+
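+        // Phase 2: split the output into tiles of up to 128 rows (gemm_m_stride) by gemm_n_stride columns,
+        // with the column stride rounded to QGEMM_STRIDEN_THREAD_ALIGN, and hand one tile per task to the
+        // remaining threads.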
+        size_t           threads_per_gemm = nth / batch_feature;
+        constexpr size_t gemm_m_stride    = 128;
+        size_t           nc               = gemm_n;
+        const size_t     gemm_m_blocked   = div_round_up(gemm_m, gemm_m_stride);
+        const size_t     max_nc           = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm);
+        if (max_nc < nc) {
+            nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN);
+        }
+        const size_t gemm_n_stride  = nc;
+        const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride);
+        const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride);
+        threads_per_gemm            = thread_count_m * thread_count_n;
+
+        {
+            int task_count      = batch_feature * threads_per_gemm;
+            int task_per_thread = (task_count + nth - 1) / nth;
+            int start           = ith * task_per_thread;
+            int end             = std::min((ith + 1) * task_per_thread, task_count);
+            for (int compute_idx = start; compute_idx < end; compute_idx++) {
+                const auto   gemm_i = compute_idx / threads_per_gemm;
+                const auto   blk_i  = compute_idx % threads_per_gemm;
+                const auto * data   = &qnbitgemm_args[gemm_i];
+
+                const auto tid_n = blk_i / thread_count_m;
+                const auto tid_m = blk_i % thread_count_m;
+
+                const size_t m_start = tid_m * gemm_m_stride;
+                const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride);
+
+                const size_t n_start = tid_n * gemm_n_stride;
+                const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride);
+
+                void * per_gemm_ws = reinterpret_cast<std::byte *>(ws) + gemm_i * per_gemm_workspace_stride;
+
+                sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count);
+            }
+        }
+    }
+
+    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
+        GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
+                       (int) NB_COLS, (int) INTER_SIZE);
+        return ggml::cpu::riscv64_spacemit::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
+    }
+};
+
+class tensor_traits_common : public tensor_traits_base {
+    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        switch (op->op) {
+            case GGML_OP_NORM:
+            case GGML_OP_RMS_NORM:
+                size = 0;
+                return true;
+            default:
+                // GGML_ABORT("fatal error");
+                break;
+        }
+        return false;
+    }
+
+    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
+        switch (op->op) {
+            case GGML_OP_NORM:
+                forward_norm_f32(params, op);
+                return true;
+            case GGML_OP_RMS_NORM:
+                forward_rms_norm_f32(params, op);
+                return true;
+            default:
+                // GGML_ABORT("fatal error");
+                break;
+        }
+        return false;
+    }
+
+    void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
+        const ggml_tensor * src0 = op->src[0];
+        ggml_tensor *       dst  = op;
+        GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        GGML_TENSOR_UNARY_OP_LOCALS
+
+        float epsilon;
+        memcpy(&epsilon, dst->op_params, sizeof(float));
+
+        GGML_ASSERT(epsilon > 0.0f);
+
+        auto * input  = (float *) src0->data;
+        auto * output = (float *) dst->data;
+
+        const auto hidden_size     = ne00;
+        const auto task_count      = ne01 * ne02 * ne03;
+        const auto task_per_thread = (task_count + nth - 1) / nth;
+
+        const auto task_begin = ith * task_per_thread;
+        const auto task_end   = std::min((ith + 1) * task_per_thread, task_count);
+
+        for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
+            auto   offset  = task_idx * hidden_size;
+            auto * p_input = const_cast<float *>(input + offset);
+
+            auto *       p_output      = output + offset;
+            auto *       p_temp_output = p_output;
+            auto *       p_gamma_data  = (const float *) nullptr;
+            auto *       p_beta_data   = (const float *) nullptr;
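+            // Pass 1: accumulate the element sum and sum of squares in m4 vector accumulators while copying
+            // the row to the output buffer; their m1 lanes are reduced to scalar mean and variance below.
+            // (p_gamma_data / p_beta_data are currently always null, so only the plain normalize branch runs.)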
+            size_t       gvl           = __riscv_vsetvlmax_e32m4();
+            vfloat32m4_t sum           = __riscv_vfmv_v_f_f32m4(0.f, gvl);
+            vfloat32m4_t sum_sq        = __riscv_vfmv_v_f_f32m4(0.f, gvl);
+            int64_t      length        = hidden_size;
+            while (length > 0) {
+                gvl                   = __riscv_vsetvl_e32m4(length);
+                // load data
+                vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);
+
+                sum    = __riscv_vfadd_vv_f32m4(sum, src_data, gvl);
+                sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);
+
+                __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);
+
+                p_input += gvl;
+                p_temp_output += gvl;
+                length -= gvl;
+            }
+
+            gvl = __riscv_vsetvlmax_e32m1();
+
+            float        mean   = 0.f;
+            vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);
+            vfloat32m1_t mean_v =
+                __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl);
+            mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl);
+            mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl);
+            mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl);
+            mean   = __riscv_vfmv_f_s_f32m1_f32(mean_v);
+            mean /= hidden_size;
+
+            vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
+                                                                __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
+            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
+            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
+            mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);
+
+            float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
+            mean_square /= hidden_size;
+            mean_square = sqrt(mean_square - mean * mean + epsilon);
+
+            mean_square   = 1.0f / mean_square;
+            length        = hidden_size;
+            p_temp_output = p_output;
+
+            if (p_gamma_data == nullptr && p_beta_data == nullptr) {
+                while (length > 0) {
+                    gvl                   = __riscv_vsetvl_e32m4(length);
+                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+                    src_data              = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
+                    src_data              = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+                    p_temp_output += gvl;
+                    p_output += gvl;
+                    length -= gvl;
+                }
+            } else if (p_beta_data == nullptr) {
+                while (length > 0) {
+                    gvl                       = __riscv_vsetvl_e32m4(length);
+                    vfloat32m4_t src_data     = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
+                    src_data                  = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
+                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+                    src_data                  = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
+                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+                    p_temp_output += gvl;
+                    p_output += gvl;
+                    p_gamma_data += gvl;
+                    length -= gvl;
+                }
+            } else if (p_gamma_data != nullptr) {
+                while (length > 0) {
+                    gvl                       = __riscv_vsetvl_e32m4(length);
+                    vfloat32m4_t src_data     = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
+                    src_data                  = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
+                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+                    src_data                  = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
+                    vfloat32m4_t beta_data_v  = __riscv_vle32_v_f32m4(p_beta_data, gvl);
+                    src_data                  = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
+                    p_beta_data += gvl;
+                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+                    p_temp_output += gvl;
+                    p_output += gvl;
+                    p_gamma_data += gvl;
+                    length -= gvl;
+                }
+            }
+        }
+    }
+
+    void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
+        const ggml_tensor * src0 = op->src[0];
+        ggml_tensor *       dst  = op;
+        GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        GGML_TENSOR_UNARY_OP_LOCALS
+
+        float epsilon;
+        memcpy(&epsilon, dst->op_params, sizeof(float));
+
+        GGML_ASSERT(epsilon > 0.0f);
+
+        auto * input  = (float *) src0->data;
+        auto * output = (float *) dst->data;
+
+        const auto hidden_size     = ne00;
+        const auto task_count      = ne01 * ne02 * ne03;
+        const auto task_per_thread = (task_count + nth - 1) / nth;
+
+        const auto task_begin = ith * task_per_thread;
+        const auto task_end   = std::min((ith + 1) * task_per_thread, task_count);
+
+        for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
+            auto   offset        = task_idx * hidden_size;
+            auto * p_input       = const_cast<float *>(input + offset);
+            auto * p_output      = output + offset;
+            auto * p_temp_output = p_output;
+            auto * p_gamma_data  = (const float *) nullptr;
+            auto * p_beta_data   = (const float *) nullptr;
+
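+            // Same two-pass structure as forward_norm_f32, but RMS norm only needs the sum of squares
+            // (no mean subtraction); gamma/beta are likewise unused here.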
+            size_t       gvl    = __riscv_vsetvlmax_e32m4();
+            // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
+            vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
+            int64_t      length = hidden_size;
+            while (length > 0) {
+                gvl                   = __riscv_vsetvl_e32m4(length);
+                // load data
+                vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);
+
+                sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);
+
+                __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);
+
+                p_input += gvl;
+                p_temp_output += gvl;
+                length -= gvl;
+            }
+
+            gvl = __riscv_vsetvlmax_e32m1();
+
+            // float mean = 0.f;
+            vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);
+
+            vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
+                                                                __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
+            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
+            mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
+            mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);
+
+            float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
+            mean_square /= hidden_size;
+
+            mean_square = sqrt(mean_square + epsilon);
+
+            mean_square   = 1.0f / mean_square;
+            length        = hidden_size;
+            p_temp_output = p_output;
+
+            if (p_gamma_data == nullptr && p_beta_data == nullptr) {
+                while (length > 0) {
+                    gvl                   = __riscv_vsetvl_e32m4(length);
+                    vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+                    src_data              = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+                    p_temp_output += gvl;
+                    p_output += gvl;
+                    length -= gvl;
+                }
+            } else if (p_beta_data == nullptr) {
+                while (length > 0) {
+                    gvl                       = __riscv_vsetvl_e32m4(length);
+                    vfloat32m4_t src_data     = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
+                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+                    src_data                  = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
+                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+                    p_temp_output += gvl;
+                    p_output += gvl;
+                    p_gamma_data += gvl;
+                    length -= gvl;
+                }
+            } else if (p_gamma_data != nullptr) {
+                while (length > 0) {
+                    gvl                       = __riscv_vsetvl_e32m4(length);
+                    vfloat32m4_t src_data     = __riscv_vle32_v_f32m4(p_temp_output, gvl);
+                    vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
+                    src_data                  = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
+                    src_data                  = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
+                    vfloat32m4_t beta_data_v  = __riscv_vle32_v_f32m4(p_beta_data, gvl);
+                    src_data                  = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
+                    p_beta_data += gvl;
+                    __riscv_vse32_v_f32m4(p_output, src_data, gvl);
+                    p_temp_output += gvl;
+                    p_output += gvl;
+                    p_gamma_data += gvl;
+                    length -= gvl;
+                }
+            }
+        }
+    }
+
+    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
+        memcpy(t->data, data, data_size);
+        return 0;
+    }
+};
+
+static const tensor_traits<block_q4_0, 8, 16> q4_0_16x8_q8_0;
+static const tensor_traits<block_q4_1, 8, 16> q4_1_16x8_q8_0;
+static const tensor_traits<block_q4_K, 8, 16> q4_k_16x8_q8_0;
+static const tensor_traits_common             rvv_impl;
+
+}  // namespace ggml::cpu::riscv64_spacemit
+
+static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) {
+    if (cur->type == GGML_TYPE_Q4_0) {
+        if (cur->ne[1] % 16 == 0) {
+            return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0;
+        }
+    } else if (cur->type == GGML_TYPE_Q4_1) {
+        if (cur->ne[1] % 16 == 0) {
+            return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0;
+        }
+    } else if (cur->type == GGML_TYPE_Q4_K) {
+        if (cur->ne[1] % 16 == 0) {
+            return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0;
+        }
+    } else if (cur->type == GGML_TYPE_F32) {
+        return &ggml::cpu::riscv64_spacemit::rvv_impl;
+    }
+
+    return nullptr;
+}
+
+static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer,
+                                                                         struct ggml_tensor *  tensor) {
+    tensor->extra =
+        (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_riscv64_spacemit_get_optimal_repack_type(tensor));
+
+    GGML_UNUSED(buffer);
+
+    return GGML_STATUS_SUCCESS;
+}
+
+static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer,
+                                                            struct ggml_tensor *  tensor,
+                                                            const void *          data,
+                                                            size_t                offset,
+                                                            size_t                size) {
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    auto tensor_traits = (ggml::cpu::riscv64_spacemit::tensor_traits_base *) tensor->extra;
+    if (tensor_traits) {
+        auto OK = tensor_traits->repack(tensor, data, size);
+        GGML_ASSERT(OK == 0);
+    }
+
+    GGML_UNUSED(buffer);
+}
+
+static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_RISCV64_SPACEMIT";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
+                                                                                        size_t size) {
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+
+    if (buffer == nullptr) {
+        return nullptr;
+    }
+
+    buffer->buft              = buft;
+    buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor;
+    buffer->iface.set_tensor  = ggml_backend_riscv64_spacemit_buffer_set_tensor;
+    buffer->iface.get_tensor  = nullptr;
+    buffer->iface.cpy_tensor  = nullptr;
+    return buffer;
+}
+
+static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return 64;
+
+    GGML_UNUSED(buft);
+}
+
+static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft,
+                                                       const struct ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        if (tensor->ne[i] <= 0) {
+            return 0;
+        }
+    }
+
+    size_t       nbytes;
+    const size_t blck_size = ggml_blck_size(tensor->type);
+    if (blck_size == 1) {
+        nbytes = ggml_type_size(tensor->type);
+        for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+            nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
+        }
+    } else {
+        nbytes = tensor->ne[0] * tensor->nb[0] / blck_size;
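+        // Q4_K tensors are repacked into 8 block_q4_1 per super-block (see repack_q4_k_to_q4_1_16_bl),
+        // so the repacked buffer needs more space than the tensor's nominal size.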
+        if (tensor->type == GGML_TYPE_Q4_K) {
+            GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0);
+            nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
+            for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+                nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
+            }
+        } else {
+            for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+                nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
+            }
+        }
+    }
+
+    GGML_UNUSED(buft);
+    return nbytes;
+}
+
+namespace ggml::cpu::riscv64_spacemit {
+
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+        switch (op->op) {
+            case GGML_OP_MUL_MAT:
+                if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) &&
+                    op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() &&
+                    ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) {
+                    if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                        return false;
+                    }
+                    if (op->src[1]->type == GGML_TYPE_F32) {
+                        return true;
+                    }
+                }
+                break;
+            case GGML_OP_NORM:
+            case GGML_OP_RMS_NORM:
+                if (op->src[0]->type == GGML_TYPE_F32) {
+                    return true;
+                }
+                break;
+            default:
+                // GGML_ABORT("fatal error");
+                break;
+        }
+        return false;
+    }
+
+    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
+        switch (op->op) {
+            case GGML_OP_MUL_MAT:
+                if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
+                    return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+                }
+                break;
+            case GGML_OP_NORM:
+            case GGML_OP_RMS_NORM:
+                return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl);
+            default:
+                // GGML_ABORT("fatal error");
+                break;
+        }
+
+        return nullptr;
+    }
+};
+
+}  // namespace ggml::cpu::riscv64_spacemit
+
+ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = {
+        /* .iface   = */ {
+            /* .get_name       = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name,
+            /* .alloc_buffer   = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer,
+            /* .get_alignment  = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment,
+            /* .get_max_size   = */ nullptr,
+            /* .get_alloc_size = */ ggml_backend_cpu_riscv64_spacemit_nbytes,
+            /* .is_host        = */ nullptr,
+        },
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ new ggml::cpu::riscv64_spacemit::extra_buffer_type(),
+    };
+
+    return &ggml_backend_cpu_buffer_type_riscv64_spacemit;
+}
diff --git a/src/ggml-cpu/spacemit/ime.h b/src/ggml-cpu/spacemit/ime.h
new file mode 100644 (file)
index 0000000..800d91a
--- /dev/null
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "ggml-alloc.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void);
+
+#ifdef __cplusplus
+}
+#endif
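
The header above is the backend's entire public surface; everything else is reached through the CPU
backend's extra-buffer-type hooks. A rough usage sketch follows (illustrative only, not part of this
commit: it assumes a build configured with GGML_CPU_RISCV64_SPACEMIT, uses only the public ggml /
ggml-alloc / ggml-backend APIs, and the upload_weights helper, the shapes and the q4_0_data buffer are
placeholders):

    // Sketch: allocate Q4_0 weights in the SpacemiT buffer type so that
    // ggml_backend_tensor_set() triggers the repack into block_q4_0x16.
    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    #include <cstdint>
    #include <vector>

    extern "C" ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void);  // from spacemit/ime.h

    void upload_weights(const std::vector<uint8_t> & q4_0_data /* pre-quantized Q4_0 bytes */) {
        struct ggml_init_params params = {
            /* .mem_size   = */ ggml_tensor_overhead() * 8,  // metadata only, no tensor data
            /* .mem_buffer = */ nullptr,
            /* .no_alloc   = */ true,
        };
        struct ggml_context * ctx = ggml_init(params);

        // ne[1] must be a multiple of 16, otherwise no repack traits are assigned (see ime.cpp above)
        struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 4096, 4096);

        ggml_backend_buffer_t buf =
            ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_riscv64_spacemit_buffer_type());

        // set_tensor() routes the incoming data through repack_q4_0_to_q4_0_16_bl()
        ggml_backend_tensor_set(w, q4_0_data.data(), 0, ggml_nbytes(w));

        ggml_backend_buffer_free(buf);
        ggml_free(ctx);
    }
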
diff --git a/src/ggml-cpu/spacemit/ime1_kernels.cpp b/src/ggml-cpu/spacemit/ime1_kernels.cpp
new file mode 100644 (file)
index 0000000..cbbb6cd
--- /dev/null
@@ -0,0 +1,3196 @@
+#include "ggml.h"
+#include "ime_kernels.h"
+
+#include <algorithm>
+#include <cmath>
+
+// clang-format off
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#pragma GCC diagnostic ignored "-Wcast-qual"
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#endif
+// clang-format on
+namespace sqnbitgemm_spacemit_ime {
+
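+// QUANTIZEM4ROW_KERNEL: for one block of a row held in v0 (e32, m8), compute the block's absolute maximum
+// with vfredmax, derive the scale absmax/127 (stored through a1) and its reciprocal, scale the elements,
+// convert to integer and narrow e32 -> e16 -> e8 into v24-v31. QUANTIZEM4ROW_STORE then writes the int8
+// result in 8-byte granules at a 32-byte stride (apparently assuming VLEN = 256) so the four A rows stay interleaved.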
+#define QUANTIZEM4ROW_KERNEL                           \
+    "vmv.s.x            v16, zero                \n\t" \
+    "vfabs.v            v8, v0                   \n\t" \
+    "vfredmax.vs        v16, v8, v16             \n\t" \
+    "vfmv.f.s           f10, v16                 \n\t" \
+    "fmul.s             f10, f10, %[RMAXREC]     \n\t" \
+    "fsw                f10, (a1)                \n\t" \
+    "fdiv.s             f11, %[FONE], f10        \n\t" \
+    "vfmul.vf           v16, v0, f11             \n\t" \
+    "vfcvt.x.f.v        v16, v16                 \n\t" \
+    "vsetvli            t0, zero, e16, mf2       \n\t" \
+    "vnclip.wx          v16, v16, zero           \n\t" \
+    "vnclip.wx          v17, v17, zero           \n\t" \
+    "vnclip.wx          v18, v18, zero           \n\t" \
+    "vnclip.wx          v19, v19, zero           \n\t" \
+    "vnclip.wx          v20, v20, zero           \n\t" \
+    "vnclip.wx          v21, v21, zero           \n\t" \
+    "vnclip.wx          v22, v22, zero           \n\t" \
+    "vnclip.wx          v23, v23, zero           \n\t" \
+    "vsetvli            t0, zero, e8, mf4        \n\t" \
+    "vnclip.wx          v24, v16, zero           \n\t" \
+    "vnclip.wx          v25, v17, zero           \n\t" \
+    "vnclip.wx          v26, v18, zero           \n\t" \
+    "vnclip.wx          v27, v19, zero           \n\t" \
+    "vnclip.wx          v28, v20, zero           \n\t" \
+    "vnclip.wx          v29, v21, zero           \n\t" \
+    "vnclip.wx          v30, v22, zero           \n\t" \
+    "vnclip.wx          v31, v23, zero           \n\t"
+
+#define QUANTIZEM4ROW_STORE                            \
+    "addi               t1, %[BlkLen], 0         \n\t" \
+    "vsetvli            t0, t1, e8, mf4          \n\t" \
+    "vse8.v             v24, (s1)                \n\t" \
+    "addi               s1, s1, 32               \n\t" \
+    "sub                t1, t1, t0               \n\t" \
+    "vsetvli            t0, t1, e8, mf4          \n\t" \
+    "vse8.v             v25, (s1)                \n\t" \
+    "addi               s1, s1, 32               \n\t" \
+    "sub                t1, t1, t0               \n\t" \
+    "vsetvli            t0, t1, e8, mf4          \n\t" \
+    "vse8.v             v26, (s1)                \n\t" \
+    "addi               s1, s1, 32               \n\t" \
+    "sub                t1, t1, t0               \n\t" \
+    "vsetvli            t0, t1, e8, mf4          \n\t" \
+    "vse8.v             v27, (s1)                \n\t" \
+    "addi               s1, s1, 32               \n\t" \
+    "sub                t1, t1, t0               \n\t" \
+    "vsetvli            t0, t1, e8, mf4          \n\t" \
+    "vse8.v             v28, (s1)                \n\t" \
+    "addi               s1, s1, 32               \n\t" \
+    "sub                t1, t1, t0               \n\t" \
+    "vsetvli            t0, t1, e8, mf4          \n\t" \
+    "vse8.v             v29, (s1)                \n\t" \
+    "addi               s1, s1, 32               \n\t" \
+    "sub                t1, t1, t0               \n\t" \
+    "vsetvli            t0, t1, e8, mf4          \n\t" \
+    "vse8.v             v30, (s1)                \n\t" \
+    "addi               s1, s1, 32               \n\t" \
+    "sub                t1, t1, t0               \n\t" \
+    "vsetvli            t0, t1, e8, mf4          \n\t" \
+    "vse8.v             v31, (s1)                \n\t"
+
+namespace ime1 {
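+// Quantize four consecutive rows of A into the interleaved Q8 layout consumed by gemm_kernel_i8i4:
+// per block, the four per-row scales are stored first (4 floats), followed by the rows' int8 data
+// interleaved in vector-store-sized granules; tail blocks shorter than BlkLen are zero-padded.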
+void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
+    constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
+    const float     fone                 = 1.0f;
+
+    if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) {
+        for (size_t row_index = 0; row_index < 4; ++row_index) {
+            const float * SRC = A + row_index * CountK;
+            std::byte *   DST = QuantA + row_index * sizeof(float);
+
+            const size_t offset = (4 - row_index) * 4 + row_index * 8;
+            const size_t stride = 4 * (sizeof(float) + BlkLen);
+            __asm__ volatile(
+                "vsetvli            t0, zero, e32, m8        \n\t"
+                "addi               t2, %[CountK], 0         \n\t"
+                "addi               a1, %[DST], 0            \n\t"
+                "blt                t2, %[BlkLen], TAIL%=    \n\t"
+
+                "LOOP%=:                                     \n\t"
+                "vsetvli            t0, %[BlkLen], e32, m8   \n\t"
+                "vle32.v            v0, (%[SRC])             \n\t"
+                "sub                t2, t2, t0               \n\t"
+                "slli               t1, t0, 2                \n\t"
+                "add                %[SRC], %[SRC], t1       \n\t"
+                "add                s1, a1, %[OFFSET]        \n\t"
+
+                QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
+
+                "add                a1, a1, %[STRIDE]        \n\t"
+                "bge                t2, %[BlkLen], LOOP%=    \n\t"
+
+                "TAIL%=:                                     \n\t"
+                "blez               t2, QUIT%=               \n\t"
+                "vsetvli            t0, zero, e32, m8        \n\t"
+                "vxor.vv            v16, v16, v16            \n\t"
+                "vxor.vv            v24, v24, v24            \n\t"
+                "vsetvli            t0, t2, e32, m8          \n\t"
+                "vle32.v            v0, (%[SRC])             \n\t"
+                "add                s1, a1, %[OFFSET]        \n\t"
+
+                QUANTIZEM4ROW_KERNEL
+
+                "addi               t3, %[BlkLen], 0         \n\t"
+                "addi               s2, s1, 0                \n\t"
+                "vsetvli            t0, zero, e8, mf4        \n\t"
+                "vxor.vv            v8, v8, v8               \n\t"
+                "SET_ZERO%=:                                 \n\t"
+                "vse8.v             v8, (s2)                 \n\t"
+                "addi               s2, s2, 32               \n\t"
+                "addi               t3, t3, -8               \n\t"
+                "bnez               t3, SET_ZERO%=           \n\t"
+
+                QUANTIZEM4ROW_STORE
+
+                "QUIT%=:                                     \n\t"
+                : [SRC] "+r"(SRC)
+                : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
+                  [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+                : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
+        }
+    } else if (BlkLen == 128) {
+        for (size_t row_index = 0; row_index < 4; ++row_index) {
+            const float * SRC = A + row_index * CountK;
+            std::byte *   DST = QuantA + row_index * sizeof(float);
+
+            const size_t offset = (4 - row_index) * 4 + row_index * 8;
+            const size_t stride = 4 * (sizeof(float) + BlkLen);
+            __asm__ volatile(
+                "vsetvli            t0, zero, e32, m8        \n\t"
+                "li                 t6, 32                   \n\t"
+                "addi               t2, %[CountK], 0         \n\t"
+                "addi               a1, %[DST], 0            \n\t"
+                "add                s1, a1, %[OFFSET]        \n\t"
+                "blt                t2, %[BlkLen], TAIL%=    \n\t"
+
+                "LOOP%=:                                     \n\t"
+                "vsetvli            t0, zero, e32, m8        \n\t"
+                "vle32.v            v0, (%[SRC])             \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "vle32.v            v8, (%[SRC])             \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "addi               t2, t2, -128             \n\t"
+
+                "QUANTIZE%=:                                 \n\t"
+                "add                s1, a1, %[OFFSET]        \n\t"
+                "vfabs.v            v16, v0                  \n\t"
+                "vfabs.v            v24, v8                  \n\t"
+                "vfmax.vv           v16, v24, v16            \n\t"
+                "vfredmax.vs        v24, v16, v24            \n\t"
+                "vfmv.f.s           f10, v24                 \n\t"
+                "fmul.s             f10, f10, %[RMAXREC]     \n\t"
+                "fsw                f10, (a1)                \n\t"
+                "fdiv.s             f11, %[FONE], f10        \n\t"
+                "vfmul.vf           v16, v0, f11             \n\t"
+                "vfmul.vf           v24, v8, f11             \n\t"
+                "vfcvt.x.f.v        v16, v16                 \n\t"
+                "vfcvt.x.f.v        v24, v24                 \n\t"
+                "vsetvli            t0, zero, e16, m4        \n\t"
+                "vnclip.wx          v16, v16, zero           \n\t"
+                "vnclip.wx          v20, v24, zero           \n\t"
+                "vsetvli            t0, zero, e8, m4         \n\t"
+                "vnclip.wx          v16, v16, zero           \n\t"
+                "vsetvli            t0, zero, e64, m4        \n\t"
+                "vsse64.v           v16, (s1), t6            \n\t"
+                "add                a1, a1, %[STRIDE]        \n\t"
+                "bge                t2, %[BlkLen], LOOP%=    \n\t"
+
+                "TAIL%=:                                     \n\t"
+                "blez               t2, QUIT%=               \n\t"
+                "vsetvli            t0, zero, e32, m8        \n\t"
+                "vxor.vv             v0, v0, v0              \n\t"
+                "vxor.vv             v8, v8, v8              \n\t"
+                "vxor.vv             v16, v16, v16           \n\t"
+                "vxor.vv             v24, v24, v24           \n\t"
+                "vsetvli            t0, t2, e32, m8          \n\t"
+                "sub                t2, t2, t0               \n\t"
+                "vle32.v            v0, (%[SRC])             \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "vsetvli            t0, t2, e32, m8          \n\t"
+                "vle32.v            v8, (%[SRC])             \n\t"
+                "sub                t2, t2, t2               \n\t"
+                "vsetvli            t0, zero, e32, m8        \n\t"
+                "jal                x0, QUANTIZE%=           \n\t"
+
+                "QUIT%=:                                     \n\t"
+                : [SRC] "+r"(SRC)
+                : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
+                  [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+                : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
+        }
+    } else if (BlkLen == 256) {
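+        // Same interleaved 4-row layout as the 128 case, but each block spans
+        // 256 floats, so the kernel walks the block twice: one pass to find the
+        // absolute maximum, then a reload pass to quantize and store, with a
+        // TAIL_LOOP covering a partial final block.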
+        for (size_t row_index = 0; row_index < 4; ++row_index) {
+            const float * SRC    = A + row_index * CountK;
+            std::byte *   DST    = QuantA + row_index * sizeof(float);
+            const size_t  offset = (4 - row_index) * 4 + row_index * 8;
+            const size_t  stride = 4 * (sizeof(float) + BlkLen);
+            __asm__ volatile(
+                "vsetvli            t0, zero, e32, m8        \n\t"
+                "li                 t6, 32                   \n\t"
+                "addi               t2, %[CountK], 0         \n\t"
+                "addi               a1, %[DST], 0            \n\t"
+                "add                s1, a1, %[OFFSET]        \n\t"
+                "blt                t2, %[BlkLen], TAIL%=    \n\t"
+
+                "LOOP%=:                                     \n\t"
+                "vsetvli            t0, zero, e32, m8        \n\t"
+                "vle32.v            v0, (%[SRC])             \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "vle32.v            v8, (%[SRC])             \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "vle32.v            v16, (%[SRC])            \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "vle32.v            v24, (%[SRC])            \n\t"
+                "addi               %[SRC], %[SRC], -768     \n\t"
+                "addi               t2, t2, -256             \n\t"
+                "vfabs.v            v0, v0                   \n\t"
+                "vfabs.v            v8, v8                   \n\t"
+                "vfabs.v            v16, v16                 \n\t"
+                "vfabs.v            v24, v24                 \n\t"
+                "vfmax.vv           v8, v0, v8               \n\t"
+                "vfmax.vv           v24, v24, v16            \n\t"
+                "vfmax.vv           v8, v8, v24              \n\t"
+                "vfredmax.vs        v24, v8, v24             \n\t"
+                "vfmv.f.s           f10, v24                 \n\t"
+                "vle32.v            v0, (%[SRC])             \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "vle32.v            v8, (%[SRC])             \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "vle32.v            v16, (%[SRC])            \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "vle32.v            v24, (%[SRC])            \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+
+                "QUANTIZE%=:                                 \n\t"
+                "add                s1, a1, %[OFFSET]        \n\t"
+                "fmul.s             f10, f10, %[RMAXREC]     \n\t"
+                "fsw                f10, (a1)                \n\t"
+                "fdiv.s             f11, %[FONE], f10        \n\t"
+                "vfmul.vf           v0, v0, f11              \n\t"
+                "vfmul.vf           v8, v8, f11              \n\t"
+                "vfmul.vf           v16, v16, f11            \n\t"
+                "vfmul.vf           v24, v24, f11            \n\t"
+                "vfcvt.x.f.v        v0, v0                   \n\t"
+                "vfcvt.x.f.v        v8, v8                   \n\t"
+                "vfcvt.x.f.v        v16, v16                 \n\t"
+                "vfcvt.x.f.v        v24, v24                 \n\t"
+                "vsetvli            t0, zero, e16, m4        \n\t"
+                "vnclip.wx          v0, v0, zero             \n\t"
+                "vnclip.wx          v4, v8, zero             \n\t"
+                "vnclip.wx          v8, v16, zero            \n\t"
+                "vnclip.wx          v12, v24, zero           \n\t"
+                "vsetvli            t0, zero, e8, m4         \n\t"
+                "vnclip.wx          v0, v0, zero             \n\t"
+                "vnclip.wx          v4, v8, zero             \n\t"
+                "vsetvli            t0, zero, e64, m8        \n\t"
+                "vsse64.v           v0, (s1), t6             \n\t"
+                "add                a1, a1, %[STRIDE]        \n\t"
+                "bge                t2, %[BlkLen], LOOP%=    \n\t"
+
+                "TAIL%=:                                     \n\t"
+                "blez               t2, QUIT%=               \n\t"
+                "vsetvli            t0, zero, e32, m8        \n\t"
+                "vxor.vv            v0, v0, v0               \n\t"
+                "vxor.vv            v8, v8, v8               \n\t"
+                "vxor.vv            v16, v16, v16            \n\t"
+                "vxor.vv            v24, v24, v24            \n\t"
+                "addi               t1, t2, 0                \n\t"
+                "vsetvli            t0, t1, e32, m8          \n\t"
+                "sub                t1, t1, t0               \n\t"
+                "vle32.v            v0, (%[SRC])             \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "vsetvli            t0, t1, e32, m8          \n\t"
+                "sub                t1, t1, t0               \n\t"
+                "vle32.v            v8, (%[SRC])             \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "vsetvli            t0, t1, e32, m8          \n\t"
+                "sub                t1, t1, t0               \n\t"
+                "vle32.v            v16, (%[SRC])            \n\t"
+                "addi               %[SRC], %[SRC], 256      \n\t"
+                "vsetvli            t0, t1, e32, m8          \n\t"
+                "vle32.v            v24, (%[SRC])            \n\t"
+                "addi               %[SRC], %[SRC], -768     \n\t"
+                "vsetvli            t0, zero, e32, m8        \n\t"
+                "vfabs.v            v0, v0                   \n\t"
+                "vfabs.v            v8, v8                   \n\t"
+                "vfabs.v            v16, v16                 \n\t"
+                "vfabs.v            v24, v24                 \n\t"
+                "vfmax.vv           v8, v0, v8               \n\t"
+                "vfmax.vv           v24, v16, v24            \n\t"
+                "vfmax.vv           v8, v8, v24              \n\t"
+                "vfredmax.vs        v24, v8, v24             \n\t"
+                "vfmv.f.s           f10, v24                 \n\t"
+                "add                s1, a1, %[OFFSET]        \n\t"
+                "fmul.s             f10, f10, %[RMAXREC]     \n\t"
+                "fsw                f10, (a1)                \n\t"
+                "fdiv.s             f11, %[FONE], f10        \n\t"
+                "vsetvli            t0, zero, e64, m8        \n\t"
+                "vxor.vv            v0, v0, v0               \n\t"
+                "vsse64.v           v0, (s1), t6             \n\t"
+
+                "TAIL_LOOP%=:                                \n\t"
+                "vsetvli            t0, zero, e32, m4        \n\t"
+                "vxor.vv            v0, v0, v0               \n\t"
+                "vsetvli            t0, t2, e32, m1          \n\t"
+                "sub                t2, t2, t0               \n\t"
+                "vle32.v            v0, (%[SRC])             \n\t"
+                "addi               %[SRC], %[SRC], 32       \n\t"
+                "vfmul.vf           v1, v0, f11              \n\t"
+                "vfcvt.x.f.v        v2, v1                   \n\t"
+                "vsetvli            t0, zero, e16, mf2       \n\t"
+                "vnclip.wx          v3, v2, zero             \n\t"
+                "vsetvli            t0, zero, e8, mf4        \n\t"
+                "vnclip.wx          v3, v3, zero             \n\t"
+                "vse8.v             v3, (s1)                 \n\t"
+                "addi               s1, s1, 32               \n\t"
+                "bnez               t2, TAIL_LOOP%=          \n\t"
+
+                "QUIT%=:                                     \n\t"
+                : [SRC] "+r"(SRC)
+                : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
+                  [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+                : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
+        }
+    }
+}
+
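+// Quantizes one row of floats into per-block int8 values: each block stores a
+// float scale (max|x| / 127) followed by BlkLen int8 codes,
+// q = clamp(round(x / scale), -128, 127). The scalar path below
+// (CountK <= BlkLen) is the reference for this layout; the vectorized branches
+// appear to produce the same layout one block at a time, zero-padding the tail
+// of the last block.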
+void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
+    const float *   SRC                  = A;
+    std::byte *     DST                  = QuantA;
+    constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
+    const float     fone                 = 1.0f;
+    std::byte *     QuantA_offset        = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
+    size_t          offset               = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
+
+    if (CountK <= BlkLen) {
+        float max_abs_A = 0.0f;
+        for (size_t k = 0; k < CountK; k++) {
+            max_abs_A = std::max(max_abs_A, fabsf(A[k]));
+        }
+        float scale_A = max_abs_A * range_max_reciprocal;
+
+        ((float *) QuantA)[0] = scale_A;
+
+        auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));
+
+        for (size_t k = 0; k < CountK; k++) {
+            QuantAData_offset[k] =
+                (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),
+                                    (float) std::numeric_limits<int8_t>::max());
+        }
+        for (size_t k = CountK; k < BlkLen; k++) {
+            QuantAData_offset[k] = 0;
+        }
+
+        return;
+    }
+
+    // Zero-fill the padding area at the end of the last block up front; the
+    // BlkLen == 32/64/128 paths below store full blocks and overwrite it anyway.
+    if (BlkLen != 32 && BlkLen != 64 && BlkLen != 128) {
+        __asm__ volatile(
+            "vsetvli      t0, zero, e8, m8        \n\t"
+            "vxor.vv      v24, v24, v24           \n\t"
+            "LOOP%=:                              \n\t"
+            "vsetvli      t0, %[CNT], e8, m8      \n\t"
+            "vse8.v       v24, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 128     \n\t"
+            "sub          %[CNT], %[CNT], t0      \n\t"
+            "bnez         %[CNT], LOOP%=          \n\t"
+            : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset)
+            :
+            : "cc", "t0");
+    }
+    if (BlkLen == 16) {
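+        // BlkLen == 16: eight 16-float blocks are processed per main-loop
+        // iteration. Per-block absolute maxima are reduced into `buffer`, then
+        // read back as scalars to form the eight scales; each output block is
+        // 20 bytes (4-byte scale + 16 int8 codes), hence the DST strides of 20.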
+        float buffer[64] = { 0.0f };
+        __asm__ volatile(
+            "addi         t3, zero, 16*8          \n\t"
+            "addi         t2, zero, 16            \n\t"
+            "blt          %[K], t3, LOOP_K%=      \n\t"
+            "blt          %[K], t2, TAIL%=        \n\t"
+            "LOOP_MAIN%=:                         \n\t"
+            "vsetvli      t1, zero, e32, m2       \n\t"
+            "addi         %[K], %[K], -128        \n\t"
+            "vle32.v      v0, (%[SRC])            \n\t"
+            "addi         %[SRC], %[SRC], 64      \n\t"
+            "vle32.v      v2, (%[SRC])            \n\t"
+            "addi         %[SRC], %[SRC], 64      \n\t"
+            "vle32.v      v4, (%[SRC])            \n\t"
+            "addi         %[SRC], %[SRC], 64      \n\t"
+            "vle32.v      v6, (%[SRC])            \n\t"
+            "addi         %[SRC], %[SRC], 64      \n\t"
+            "vle32.v      v8, (%[SRC])            \n\t"
+            "addi         %[SRC], %[SRC], 64      \n\t"
+            "vle32.v      v10, (%[SRC])           \n\t"
+            "addi         %[SRC], %[SRC], 64      \n\t"
+            "vle32.v      v12, (%[SRC])           \n\t"
+            "addi         %[SRC], %[SRC], 64      \n\t"
+            "vle32.v      v14, (%[SRC])           \n\t"
+            "addi         %[SRC], %[SRC], 64      \n\t"
+            "addi         a1, %[BUFFER], 0        \n\t"
+            "vfabs.v      v16, v0                 \n\t"
+            "vfabs.v      v18, v2                 \n\t"
+            "vfabs.v      v20, v4                 \n\t"
+            "vfabs.v      v22, v6                 \n\t"
+            "vfabs.v      v24, v8                 \n\t"
+            "vfabs.v      v26, v10                \n\t"
+            "vfabs.v      v28, v12                \n\t"
+            "vfabs.v      v30, v14                \n\t"
+            "vsetvli      t0, zero, e32, m1       \n\t"
+            "vfmax.vv     v16, v16, v17           \n\t"
+            "vfmax.vv     v18, v18, v19           \n\t"
+            "vfmax.vv     v20, v20, v21           \n\t"
+            "vfmax.vv     v22, v22, v23           \n\t"
+            "vfmax.vv     v24, v24, v25           \n\t"
+            "vfmax.vv     v26, v26, v27           \n\t"
+            "vfmax.vv     v28, v28, v29           \n\t"
+            "vfmax.vv     v30, v30, v31           \n\t"
+            "vse32.v      v16, (a1)               \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "vse32.v      v18, (a1)               \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "vse32.v      v20, (a1)               \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "vse32.v      v22, (a1)               \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "vse32.v      v24, (a1)               \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "vse32.v      v26, (a1)               \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "vse32.v      v28, (a1)               \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "vse32.v      v30, (a1)               \n\t"
+            "addi         a1, %[BUFFER], 0        \n\t"
+            "flw          f0, (a1)                \n\t"
+            "flw          f1, 4(a1)               \n\t"
+            "flw          f2, 8(a1)               \n\t"
+            "flw          f3, 12(a1)              \n\t"
+            "flw          f4, 16(a1)              \n\t"
+            "flw          f5, 20(a1)              \n\t"
+            "flw          f6, 24(a1)              \n\t"
+            "flw          f7, 28(a1)              \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "fmax.s       f1, f0, f1              \n\t"
+            "fmax.s       f3, f2, f3              \n\t"
+            "fmax.s       f5, f4, f5              \n\t"
+            "fmax.s       f7, f6, f7              \n\t"
+            "fmax.s       f3, f1, f3              \n\t"
+            "fmax.s       f7, f5, f7              \n\t"
+            "fmax.s       f10, f3, f7             \n\t"
+            "fmul.s       f10, f10, %[RMAXREC]    \n\t"
+            "fsw          f10, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "fdiv.s       f10, %[FONE], f10       \n\t"
+            "flw          f0, (a1)                \n\t"
+            "flw          f1, 4(a1)               \n\t"
+            "flw          f2, 8(a1)               \n\t"
+            "flw          f3, 12(a1)              \n\t"
+            "flw          f4, 16(a1)              \n\t"
+            "flw          f5, 20(a1)              \n\t"
+            "flw          f6, 24(a1)              \n\t"
+            "flw          f7, 28(a1)              \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "fmax.s       f1, f0, f1              \n\t"
+            "fmax.s       f3, f2, f3              \n\t"
+            "fmax.s       f5, f4, f5              \n\t"
+            "fmax.s       f7, f6, f7              \n\t"
+            "fmax.s       f3, f1, f3              \n\t"
+            "fmax.s       f7, f5, f7              \n\t"
+            "fmax.s       f11, f3, f7             \n\t"
+            "fmul.s       f11, f11, %[RMAXREC]    \n\t"
+            "fsw          f11, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "fdiv.s       f11, %[FONE], f11       \n\t"
+            "flw          f0, (a1)                \n\t"
+            "flw          f1, 4(a1)               \n\t"
+            "flw          f2, 8(a1)               \n\t"
+            "flw          f3, 12(a1)              \n\t"
+            "flw          f4, 16(a1)              \n\t"
+            "flw          f5, 20(a1)              \n\t"
+            "flw          f6, 24(a1)              \n\t"
+            "flw          f7, 28(a1)              \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "fmax.s       f1, f0, f1              \n\t"
+            "fmax.s       f3, f2, f3              \n\t"
+            "fmax.s       f5, f4, f5              \n\t"
+            "fmax.s       f7, f6, f7              \n\t"
+            "fmax.s       f3, f1, f3              \n\t"
+            "fmax.s       f7, f5, f7              \n\t"
+            "fmax.s       f12, f3, f7             \n\t"
+            "fmul.s       f12, f12, %[RMAXREC]    \n\t"
+            "fsw          f12, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "fdiv.s       f12, %[FONE], f12       \n\t"
+            "flw          f0, (a1)                \n\t"
+            "flw          f1, 4(a1)               \n\t"
+            "flw          f2, 8(a1)               \n\t"
+            "flw          f3, 12(a1)              \n\t"
+            "flw          f4, 16(a1)              \n\t"
+            "flw          f5, 20(a1)              \n\t"
+            "flw          f6, 24(a1)              \n\t"
+            "flw          f7, 28(a1)              \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "fmax.s       f1, f0, f1              \n\t"
+            "fmax.s       f3, f2, f3              \n\t"
+            "fmax.s       f5, f4, f5              \n\t"
+            "fmax.s       f7, f6, f7              \n\t"
+            "fmax.s       f3, f1, f3              \n\t"
+            "fmax.s       f7, f5, f7              \n\t"
+            "fmax.s       f13, f3, f7             \n\t"
+            "fmul.s       f13, f13, %[RMAXREC]    \n\t"
+            "fsw          f13, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "fdiv.s       f13, %[FONE], f13       \n\t"
+            "flw          f0, (a1)                \n\t"
+            "flw          f1, 4(a1)               \n\t"
+            "flw          f2, 8(a1)               \n\t"
+            "flw          f3, 12(a1)              \n\t"
+            "flw          f4, 16(a1)              \n\t"
+            "flw          f5, 20(a1)              \n\t"
+            "flw          f6, 24(a1)              \n\t"
+            "flw          f7, 28(a1)              \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "fmax.s       f1, f0, f1              \n\t"
+            "fmax.s       f3, f2, f3              \n\t"
+            "fmax.s       f5, f4, f5              \n\t"
+            "fmax.s       f7, f6, f7              \n\t"
+            "fmax.s       f3, f1, f3              \n\t"
+            "fmax.s       f7, f5, f7              \n\t"
+            "fmax.s       f14, f3, f7             \n\t"
+            "fmul.s       f14, f14, %[RMAXREC]    \n\t"
+            "fsw          f14, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "fdiv.s       f14, %[FONE], f14       \n\t"
+            "flw          f0, (a1)                \n\t"
+            "flw          f1, 4(a1)               \n\t"
+            "flw          f2, 8(a1)               \n\t"
+            "flw          f3, 12(a1)              \n\t"
+            "flw          f4, 16(a1)              \n\t"
+            "flw          f5, 20(a1)              \n\t"
+            "flw          f6, 24(a1)              \n\t"
+            "flw          f7, 28(a1)              \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "fmax.s       f1, f0, f1              \n\t"
+            "fmax.s       f3, f2, f3              \n\t"
+            "fmax.s       f5, f4, f5              \n\t"
+            "fmax.s       f7, f6, f7              \n\t"
+            "fmax.s       f3, f1, f3              \n\t"
+            "fmax.s       f7, f5, f7              \n\t"
+            "fmax.s       f15, f3, f7             \n\t"
+            "fmul.s       f15, f15, %[RMAXREC]    \n\t"
+            "fsw          f15, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "fdiv.s       f15, %[FONE], f15       \n\t"
+            "flw          f0, (a1)                \n\t"
+            "flw          f1, 4(a1)               \n\t"
+            "flw          f2, 8(a1)               \n\t"
+            "flw          f3, 12(a1)              \n\t"
+            "flw          f4, 16(a1)              \n\t"
+            "flw          f5, 20(a1)              \n\t"
+            "flw          f6, 24(a1)              \n\t"
+            "flw          f7, 28(a1)              \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "fmax.s       f1, f0, f1              \n\t"
+            "fmax.s       f3, f2, f3              \n\t"
+            "fmax.s       f5, f4, f5              \n\t"
+            "fmax.s       f7, f6, f7              \n\t"
+            "fmax.s       f3, f1, f3              \n\t"
+            "fmax.s       f7, f5, f7              \n\t"
+            "fmax.s       f16, f3, f7             \n\t"
+            "fmul.s       f16, f16, %[RMAXREC]    \n\t"
+            "fsw          f16, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "fdiv.s       f16, %[FONE], f16       \n\t"
+            "flw          f0, (a1)                \n\t"
+            "flw          f1, 4(a1)               \n\t"
+            "flw          f2, 8(a1)               \n\t"
+            "flw          f3, 12(a1)              \n\t"
+            "flw          f4, 16(a1)              \n\t"
+            "flw          f5, 20(a1)              \n\t"
+            "flw          f6, 24(a1)              \n\t"
+            "flw          f7, 28(a1)              \n\t"
+            "addi         a1, a1, 32              \n\t"
+            "fmax.s       f1, f0, f1              \n\t"
+            "fmax.s       f3, f2, f3              \n\t"
+            "fmax.s       f5, f4, f5              \n\t"
+            "fmax.s       f7, f6, f7              \n\t"
+            "fmax.s       f3, f1, f3              \n\t"
+            "fmax.s       f7, f5, f7              \n\t"
+            "fmax.s       f17, f3, f7             \n\t"
+            "fmul.s       f17, f17, %[RMAXREC]    \n\t"
+            "fsw          f17, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], -136    \n\t"
+            "fdiv.s       f17, %[FONE], f17       \n\t"
+            "vsetvli      t0, zero, e32, m2       \n\t"
+            "vfmul.vf     v16, v0, f10            \n\t"
+            "vfmul.vf     v18, v2, f11            \n\t"
+            "vfmul.vf     v20, v4, f12            \n\t"
+            "vfmul.vf     v22, v6, f13            \n\t"
+            "vfmul.vf     v24, v8, f14            \n\t"
+            "vfmul.vf     v26, v10, f15           \n\t"
+            "vfmul.vf     v28, v12, f16           \n\t"
+            "vfmul.vf     v30, v14, f17           \n\t"
+            "vfcvt.x.f.v  v16, v16                \n\t"
+            "vfcvt.x.f.v  v18, v18                \n\t"
+            "vfcvt.x.f.v  v20, v20                \n\t"
+            "vfcvt.x.f.v  v22, v22                \n\t"
+            "vfcvt.x.f.v  v24, v24                \n\t"
+            "vfcvt.x.f.v  v26, v26                \n\t"
+            "vfcvt.x.f.v  v28, v28                \n\t"
+            "vfcvt.x.f.v  v30, v30                \n\t"
+            "vsetvli      t0, zero, e16, m1       \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vnclip.wx    v18, v18, zero          \n\t"
+            "vnclip.wx    v20, v20, zero          \n\t"
+            "vnclip.wx    v22, v22, zero          \n\t"
+            "vnclip.wx    v24, v24, zero          \n\t"
+            "vnclip.wx    v26, v26, zero          \n\t"
+            "vnclip.wx    v28, v28, zero          \n\t"
+            "vnclip.wx    v30, v30, zero          \n\t"
+            "vsetvli      t0, t1, e8, mf2         \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vnclip.wx    v18, v18, zero          \n\t"
+            "vnclip.wx    v20, v20, zero          \n\t"
+            "vnclip.wx    v22, v22, zero          \n\t"
+            "vnclip.wx    v24, v24, zero          \n\t"
+            "vnclip.wx    v26, v26, zero          \n\t"
+            "vnclip.wx    v28, v28, zero          \n\t"
+            "vnclip.wx    v30, v30, zero          \n\t"
+            "vse8.v       v16, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "vse8.v       v18, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "vse8.v       v20, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "vse8.v       v22, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "vse8.v       v24, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "vse8.v       v26, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "vse8.v       v28, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 20      \n\t"
+            "vse8.v       v30, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 16      \n\t"
+            "bge          %[K], t3, LOOP_MAIN%=   \n\t"
+            "blt          %[K], t2, TAIL%=        \n\t"
+            "LOOP_K%=:                            \n\t"
+            "vsetvli      t1, %[K], e32, m2       \n\t"
+            "vle32.v      v0, (%[SRC])            \n\t"
+            "addi         %[SRC], %[SRC], 64      \n\t"
+            "sub          %[K], %[K], t1          \n\t"
+            "vfabs.v      v16, v0                 \n\t"
+            "vsetvli      t0, zero, e32, m1       \n\t"
+            "vfmax.vv     v16, v16, v17           \n\t"
+            "vse32.v      v16, (%[BUFFER])        \n\t"
+            "flw          f0, (%[BUFFER])         \n\t"
+            "flw          f1, 4(%[BUFFER])        \n\t"
+            "flw          f2, 8(%[BUFFER])        \n\t"
+            "flw          f3, 12(%[BUFFER])       \n\t"
+            "flw          f4, 16(%[BUFFER])       \n\t"
+            "flw          f5, 20(%[BUFFER])       \n\t"
+            "flw          f6, 24(%[BUFFER])       \n\t"
+            "flw          f7, 28(%[BUFFER])       \n\t"
+            "fmax.s       f1, f0, f1              \n\t"
+            "fmax.s       f3, f2, f3              \n\t"
+            "fmax.s       f5, f4, f5              \n\t"
+            "fmax.s       f7, f6, f7              \n\t"
+            "fmax.s       f3, f1, f3              \n\t"
+            "fmax.s       f7, f5, f7              \n\t"
+            "fmax.s       f10, f3, f7             \n\t"
+            "fmul.s       f10, f10, %[RMAXREC]    \n\t"
+            "fsw          f10, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 4       \n\t"
+            "fdiv.s       f11, %[FONE], f10       \n\t"
+            "vsetvli      t0, zero, e32, m2       \n\t"
+            "vfmul.vf     v16, v0, f11            \n\t"
+            "vfcvt.x.f.v  v16, v16                \n\t"
+            "vsetvli      t0, zero, e16, m1       \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vsetvli      t0, t1, e8, mf2         \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vse8.v       v16, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 16      \n\t"
+            "bge          %[K], t2, LOOP_K%=      \n\t"
+            "TAIL%=:                              \n\t"
+            "blez         %[K], END%=             \n\t"
+            "vsetvli      t0, t3, e32, m2         \n\t"
+            "vxor.vv      v16, v16, v16           \n\t"
+            "jal          x0, LOOP_K%=            \n\t"
+            "END%=:                               \n\t"
+            : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
+            : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer)
+            : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12",
+              "f13", "f14", "f15", "f16", "f17");
+    } else if (BlkLen == 32) {
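+        // BlkLen == 32: four 32-float blocks are quantized per main-loop
+        // iteration through staggered pointers (a1..a4 read the source,
+        // s1..s4 write four consecutive 36-byte output blocks, each a 4-byte
+        // scale followed by 32 int8 codes).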
+        __asm__ volatile(
+            "addi         t3, zero, 32*4          \n\t"
+            "addi         t2, zero, 32            \n\t"
+
+            "addi         a1, %[SRC], 0           \n\t"
+            "addi         a2, %[SRC], 128         \n\t"
+            "addi         a3, %[SRC], 256         \n\t"
+            "addi         a4, %[SRC], 384         \n\t"
+
+            "addi         s1, %[DST], 0           \n\t"
+            "addi         s2, %[DST], 36          \n\t"
+            "addi         s3, %[DST], 72          \n\t"
+            "addi         s4, %[DST], 108         \n\t"
+            "blt          %[K], t3, LOOP_K%=      \n\t"
+            "blt          %[K], t2, TAIL%=        \n\t"
+
+            "LOOP_MAIN%=:                         \n\t"
+            "vsetvli      t1, zero, e32, m4       \n\t"
+            "addi         %[K], %[K], -128        \n\t"
+            "vle32.v      v0, (a1)                \n\t"
+            "addi         a1, a1, 512             \n\t"
+            "vle32.v      v4, (a2)                \n\t"
+            "addi         a2, a2, 512             \n\t"
+            "vle32.v      v8, (a3)                \n\t"
+            "addi         a3, a3, 512             \n\t"
+            "vle32.v      v12, (a4)               \n\t"
+            "addi         a4, a4, 512             \n\t"
+            "vfabs.v      v16, v0                 \n\t"
+            "vfabs.v      v20, v4                 \n\t"
+            "vfabs.v      v24, v8                 \n\t"
+            "vfabs.v      v28, v12                \n\t"
+            "vsetvli      t0, zero, e32, m2       \n\t"
+            "vfmax.vv     v16, v16, v18           \n\t"
+            "vfmax.vv     v20, v20, v22           \n\t"
+            "vfmax.vv     v24, v24, v26           \n\t"
+            "vfmax.vv     v28, v28, v30           \n\t"
+            "vsetvli      t0, zero, e32, m1       \n\t"
+            "vfmax.vv     v16, v16, v17           \n\t"
+            "vfmax.vv     v20, v20, v21           \n\t"
+            "vfmax.vv     v24, v24, v25           \n\t"
+            "vfmax.vv     v28, v28, v29           \n\t"
+
+            "vfredmax.vs  v17, v16, v17           \n\t"
+            "vfredmax.vs  v21, v20, v21           \n\t"
+            "vfredmax.vs  v25, v24, v25           \n\t"
+            "vfredmax.vs  v29, v28, v29           \n\t"
+            "vfmv.f.s     f10,  v17               \n\t"
+            "vfmv.f.s     f11,  v21               \n\t"
+            "vfmv.f.s     f12,  v25               \n\t"
+            "vfmv.f.s     f13,  v29               \n\t"
+
+            "fmul.s       f10, f10, %[RMAXREC]    \n\t"
+            "fmul.s       f11, f11, %[RMAXREC]    \n\t"
+            "fmul.s       f12, f12, %[RMAXREC]    \n\t"
+            "fmul.s       f13, f13, %[RMAXREC]    \n\t"
+            "fsw          f10, (s1)               \n\t"
+            "addi         s1, s1, 4               \n\t"
+
+            "fsw          f11, (s2)               \n\t"
+            "addi         s2, s2, 4               \n\t"
+            "fsw          f12, (s3)               \n\t"
+            "addi         s3, s3, 4               \n\t"
+            "fsw          f13, (s4)               \n\t"
+            "addi         s4, s4, 4               \n\t"
+            "fdiv.s       f10, %[FONE], f10       \n\t"
+            "fdiv.s       f11, %[FONE], f11       \n\t"
+            "fdiv.s       f12, %[FONE], f12       \n\t"
+            "fdiv.s       f13, %[FONE], f13       \n\t"
+            "vsetvli      t0, zero, e32, m4       \n\t"
+            "vfmul.vf     v16, v0, f10            \n\t"
+            "vfmul.vf     v20, v4, f11            \n\t"
+            "vfmul.vf     v24, v8, f12            \n\t"
+            "vfmul.vf     v28, v12, f13           \n\t"
+            "vfcvt.x.f.v  v16, v16                \n\t"
+            "vfcvt.x.f.v  v20, v20                \n\t"
+            "vfcvt.x.f.v  v24, v24                \n\t"
+            "vfcvt.x.f.v  v28, v28                \n\t"
+            "vsetvli      t0, zero, e16, m2       \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vnclip.wx    v20, v20, zero          \n\t"
+            "vnclip.wx    v24, v24, zero          \n\t"
+            "vnclip.wx    v28, v28, zero          \n\t"
+            "vsetvli      t0, t1, e8, m1          \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vnclip.wx    v20, v20, zero          \n\t"
+            "vnclip.wx    v24, v24, zero          \n\t"
+            "vnclip.wx    v28, v28, zero          \n\t"
+            "vse8.v       v16, (s1)               \n\t"
+            "addi         s1, s1, 140             \n\t"
+            "vse8.v       v20, (s2)               \n\t"
+            "addi         s2, s2, 140             \n\t"
+            "vse8.v       v24, (s3)               \n\t"
+            "addi         s3, s3, 140             \n\t"
+            "vse8.v       v28, (s4)               \n\t"
+            "addi         s4, s4, 140             \n\t"
+            "bge          %[K], t3, LOOP_MAIN%=   \n\t"
+            "blt          %[K], t2, TAIL%=        \n\t"
+            "LOOP_K%=:                            \n\t"
+            "vsetvli      t1, %[K], e32, m4       \n\t"
+            "vle32.v      v0, (a1)                \n\t"
+            "addi         a1, a1, 128             \n\t"
+            "sub          %[K], %[K], t1          \n\t"
+            "vfabs.v      v16, v0                 \n\t"
+            "vsetvli      t0, zero, e32, m2       \n\t"
+            "vfmax.vv     v16, v16, v18           \n\t"
+            "vsetvli      t0, zero, e32, m1       \n\t"
+            "vfmax.vv     v16, v16, v17           \n\t"
+            "vfredmax.vs  v17, v16, v17           \n\t"
+            "vfmv.f.s     f10,  v17               \n\t"
+
+            "fmul.s       f10, f10, %[RMAXREC]    \n\t"
+            "fsw          f10, (s1)               \n\t"
+            "addi         s1, s1, 4               \n\t"
+            "fdiv.s       f11, %[FONE], f10       \n\t"
+            "vsetvli      t0, zero, e32, m4       \n\t"
+            "vfmul.vf     v16, v0, f11            \n\t"
+            "vfcvt.x.f.v  v16, v16                \n\t"
+            "vsetvli      t0, zero, e16, m2       \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vsetvli      t0, zero, e8, m1        \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vse8.v       v16, (s1)               \n\t"
+            "addi         s1, s1, 32              \n\t"
+            "bge          %[K], t2, LOOP_K%=      \n\t"
+            "TAIL%=:                              \n\t"
+            "blez         %[K], END%=             \n\t"
+            "vsetvli      t0, t3, e32, m4         \n\t"
+            "vxor.vv      v0, v0, v0              \n\t"
+            "vxor.vv      v16, v16, v16           \n\t"
+            "jal          x0, LOOP_K%=            \n\t"
+            "END%=:                               \n\t"
+            : [K] "+r"(CountK)
+            : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
+            : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
+    } else if (BlkLen == 64) {
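+        // BlkLen == 64: two blocks per iteration via a1/a2 and s1/s2; each
+        // output block is 68 bytes (4-byte scale + 64 int8 codes).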
+        __asm__ volatile(
+            "addi         t3, zero, 64*2          \n\t"
+            "addi         t2, zero, 64            \n\t"
+            "addi         a1, %[SRC], 0           \n\t"
+            "addi         a2, %[SRC], 256         \n\t"
+            "addi         s1, %[DST], 0           \n\t"
+            "addi         s2, %[DST], 68          \n\t"
+            "blt          %[K], t3, LOOP_K%=      \n\t"
+            "blt          %[K], t2, TAIL%=        \n\t"
+            "LOOP_MAIN%=:                         \n\t"
+            "vsetvli      t1, zero, e32, m8       \n\t"
+            "addi         %[K], %[K], -128        \n\t"
+            "vle32.v      v0, (a1)                \n\t"
+            "addi         a1, a1, 512             \n\t"
+            "vle32.v      v8, (a2)                \n\t"
+            "addi         a2, a2, 512             \n\t"
+            "vfabs.v      v16, v0                 \n\t"
+            "vfabs.v      v24, v8                 \n\t"
+            "vsetvli      t0, zero, e32, m4       \n\t"
+            "vfmax.vv     v16, v16, v20           \n\t"
+            "vfmax.vv     v24, v24, v28           \n\t"
+            "vsetvli      t0, zero, e32, m2       \n\t"
+            "vfmax.vv     v16, v16, v18           \n\t"
+            "vfmax.vv     v24, v24, v26           \n\t"
+            "vsetvli      t0, zero, e32, m1       \n\t"
+            "vfmax.vv     v16, v16, v17           \n\t"
+            "vfmax.vv     v24, v24, v25           \n\t"
+            "vfredmax.vs  v17, v16, v17           \n\t"
+            "vfredmax.vs  v25, v24, v25           \n\t"
+            "vfmv.f.s     f10,  v17               \n\t"
+            "vfmv.f.s     f11,  v25               \n\t"
+            "fmul.s       f10, f10, %[RMAXREC]    \n\t"
+            "fmul.s       f11, f11, %[RMAXREC]    \n\t"
+            "fsw          f10, (s1)               \n\t"
+            "addi         s1, s1, 4               \n\t"
+            "fsw          f11, (s2)               \n\t"
+            "addi         s2, s2, 4               \n\t"
+            "fdiv.s       f10, %[FONE], f10       \n\t"
+            "fdiv.s       f11, %[FONE], f11       \n\t"
+            "vsetvli      t0, zero, e32, m8       \n\t"
+            "vfmul.vf     v16, v0, f10            \n\t"
+            "vfmul.vf     v24, v8, f11            \n\t"
+            "vfcvt.x.f.v  v16, v16                \n\t"
+            "vfcvt.x.f.v  v24, v24                \n\t"
+            "vsetvli      t0, zero, e16, m4       \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vnclip.wx    v24, v24, zero          \n\t"
+            "vsetvli      t0, t1, e8, m2          \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vnclip.wx    v24, v24, zero          \n\t"
+            "vse8.v       v16, (s1)               \n\t"
+            "addi         s1, s1, 132             \n\t"
+            "vse8.v       v24, (s2)               \n\t"
+            "addi         s2, s2, 132             \n\t"
+            "bge          %[K], t3, LOOP_MAIN%=   \n\t"
+            "blt          %[K], t2, TAIL%=        \n\t"
+            "LOOP_K%=:                            \n\t"
+            "vsetvli      t1, %[K], e32, m8       \n\t"
+            "vle32.v      v0, (a1)                \n\t"
+            "addi         a1, a1, 256             \n\t"
+            "sub          %[K], %[K], t1          \n\t"
+            "vfabs.v      v16, v0                 \n\t"
+            "vsetvli      t0, zero, e32, m4       \n\t"
+            "vfmax.vv     v16, v16, v20           \n\t"
+            "vsetvli      t0, zero, e32, m2       \n\t"
+            "vfmax.vv     v16, v16, v18           \n\t"
+            "vsetvli      t0, zero, e32, m1       \n\t"
+            "vfmax.vv     v16, v16, v17           \n\t"
+            "vfredmax.vs  v17, v16, v17           \n\t"
+            "vfmv.f.s     f10,  v17               \n\t"
+            "fmul.s       f10, f10, %[RMAXREC]    \n\t"
+            "fsw          f10, (s1)               \n\t"
+            "addi         s1, s1, 4               \n\t"
+            "fdiv.s       f11, %[FONE], f10       \n\t"
+            "vsetvli      t0, zero, e32, m8       \n\t"
+            "vfmul.vf     v16, v0, f11            \n\t"
+            "vfcvt.x.f.v  v16, v16                \n\t"
+            "vsetvli      t0, zero, e16, m4       \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vsetvli      t0, zero, e8, m2        \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vse8.v       v16, (s1)               \n\t"
+            "addi         s1, s1, 64              \n\t"
+            "bge          %[K], t2, LOOP_K%=      \n\t"
+            "TAIL%=:                              \n\t"
+            "blez         %[K], END%=             \n\t"
+            "vsetvli      t0, t3, e32, m8         \n\t"
+            "vxor.vv      v0, v0, v0              \n\t"
+            "vxor.vv      v16, v16, v16           \n\t"
+            "jal          x0, LOOP_K%=            \n\t"
+            "END%=:                               \n\t"
+            : [K] "+r"(CountK)
+            : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+            : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11");
+    } else if (BlkLen == 128) {
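+        // BlkLen == 128: one block per iteration; the 128 absolute values are
+        // reduced to a single maximum before the 4-byte scale and 128 int8
+        // codes are written.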
+        __asm__ volatile(
+            "addi         t2, zero, 128           \n\t"
+            "addi         a1, %[SRC], 0           \n\t"
+            "addi         a2, %[SRC], 256         \n\t"
+            "blt          %[K], t2, TAIL%=        \n\t"
+            "LOOP_K%=:                            \n\t"
+            "vsetvli      t1, zero, e32, m8       \n\t"
+            "vle32.v      v0, (a1)                \n\t"
+            "addi         a1, a1, 512             \n\t"
+            "vle32.v      v8, (a2)                \n\t"
+            "addi         a2, a2, 512             \n\t"
+            "sub          %[K], %[K], t2          \n\t"
+            "QUANT%=:                             \n\t"
+            "vfabs.v      v16, v0                 \n\t"
+            "vfabs.v      v24, v8                 \n\t"
+            "vfmax.vv     v24, v16, v24           \n\t"
+            "vsetvli      t1, zero, e32, m4       \n\t"
+            "vfmax.vv     v28, v24, v28           \n\t"
+            "vsetvli      t0, zero, e32, m2       \n\t"
+            "vfmax.vv     v30, v28, v30           \n\t"
+            "vsetvli      t0, zero, e32, m1       \n\t"
+            "vfmax.vv     v30, v30, v31           \n\t"
+            "vfredmax.vs  v31, v30, v31           \n\t"
+            "vfmv.f.s     f10, v31                \n\t"
+            "fmul.s       f10, f10, %[RMAXREC]    \n\t"
+            "fsw          f10, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 4       \n\t"
+            "fdiv.s       f11, %[FONE], f10       \n\t"
+            "vsetvli      t0, zero, e32, m8       \n\t"
+            "vfmul.vf     v16, v0, f11            \n\t"
+            "vfmul.vf     v24, v8, f11            \n\t"
+            "vfcvt.x.f.v  v16, v16                \n\t"
+            "vfcvt.x.f.v  v24, v24                \n\t"
+            "vsetvli      t0, zero, e16, m4       \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vnclip.wx    v20, v24, zero          \n\t"
+            "vsetvli      t0, zero, e8, m4        \n\t"
+            "vnclip.wx    v16, v16, zero          \n\t"
+            "vse8.v       v16, (%[DST])           \n\t"
+            "addi         %[DST], %[DST], 128     \n\t"
+            "bge          %[K], t2, LOOP_K%=      \n\t"
+            "TAIL%=:                              \n\t"
+            "blez         %[K], END%=             \n\t"
+            "vsetvli      t1, zero, e32, m8       \n\t"
+            "vxor.vv      v0, v0, v0              \n\t"
+            "vxor.vv      v8, v8, v8              \n\t"
+            "vsetvli      t0, %[K], e32, m8       \n\t"
+            "vle32.v      v0, (a1)                \n\t"
+            "sub          %[K], %[K], t0          \n\t"
+            "vsetvli      t0, %[K], e32, m8       \n\t"
+            "vle32.v      v8, (a2)                \n\t"
+            "sub          %[K], %[K], t0          \n\t"
+            "vsetvli      t1, zero, e32, m8       \n\t"
+            "jal          x0, QUANT%=             \n\t"
+            "END%=:                               \n\t"
+
+            : [DST] "+r"(DST), [K] "+r"(CountK)
+            : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC)
+            : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11");
+    } else {
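+        // Generic path for larger block sizes (BlkLen a multiple of 256): each
+        // block is processed in `cnt` chunks of 256 floats, first to accumulate
+        // the per-lane maxima in `buffer`, then again to quantize, with the
+        // TAIL_* loops covering a partial final block.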
+        float  buffer[8] = { 0.0f };
+        size_t cnt       = BlkLen / 256;
+
+        __asm__ volatile(
+            "slli         t3, %[BLK], 2           \n\t"
+            "blt       %[K], %[BLK], LOOP_TAIL%=  \n\t"
+            "LOOP_MAIN%=:                         \n\t"
+            "vsetvli      t0, zero, e32, m1       \n\t"
+            "vxor.vv      v31, v31, v31           \n\t"
+            "vse32.v      v31, (%[BUFFER])        \n\t"
+            "addi         t6, %[CNT], 0           \n\t"
+            "LOOP_CMP%=:                          \n\t"
+            "addi         t6, t6, -1              \n\t"
+            "vsetvli      t0, zero, e32, m8       \n\t"
+            "vle32.v      v0, (%[SRC])            \n\t"
+            "addi         %[SRC], %[SRC], 256     \n\t"
+            "vle32.v      v8, (%[SRC])            \n\t"
+            "addi         %[SRC], %[SRC], 256     \n\t"
+            "vle32.v      v16, (%[SRC])           \n\t"
+            "addi         %[SRC], %[SRC], 256     \n\t"
+            "vle32.v      v24, (%[SRC])           \n\t"
+            "addi         %[SRC], %[SRC], 256     \n\t"
+            "vfabs.v      v0, v0                  \n\t"
+            "vfabs.v      v8, v8                  \n\t"
+            "vfabs.v      v16, v16                \n\t"
+            "vfabs.v      v24, v24                \n\t"
+            "vfmax.vv     v8, v0, v8              \n\t"
+            "vfmax.vv     v16, v16, v24           \n\t"
+            "vfmax.vv     v0, v0, v16             \n\t"
+            "vsetvli      t0, zero, e32, m4       \n\t"
+            "vfmax.vv     v0, v0, v4              \n\t"
+            "vsetvli      t0, zero, e32, m2       \n\t"
+            "vfmax.vv     v0, v0, v2              \n\t"
+            "vsetvli      t0, zero, e32, m1       \n\t"
+            "vfmax.vv     v0, v0, v1              \n\t"
+            "vle32.v      v30, (%[BUFFER])        \n\t"
+            "vfmax.vv     v31, v30,  v0           \n\t"
+            "vse32.v      v31, (%[BUFFER])        \n\t"
+            "bnez         t6, LOOP_CMP%=          \n\t"
+            "sub          %[SRC], %[SRC], t3      \n\t"
+            "addi         t6, %[CNT], 0           \n\t"
+            "flw          f0, (%[BUFFER])         \n\t"
+            "flw          f1, 4(%[BUFFER])        \n\t"
+            "flw          f2, 8(%[BUFFER])        \n\t"
+            "flw          f3, 12(%[BUFFER])       \n\t"
+            "flw          f4, 16(%[BUFFER])       \n\t"
+            "flw          f5, 20(%[BUFFER])       \n\t"
+            "flw          f6, 24(%[BUFFER])       \n\t"
+            "flw          f7, 28(%[BUFFER])       \n\t"
+            "fmax.s       f1, f0, f1              \n\t"
+            "fmax.s       f3, f2, f3              \n\t"
+            "fmax.s       f5, f4, f5              \n\t"
+            "fmax.s       f7, f6, f7              \n\t"
+            "fmax.s       f3, f1, f3              \n\t"
+            "fmax.s       f7, f5, f7              \n\t"
+            "fmax.s       f10, f3, f7             \n\t"
+            "fmul.s       f10, f10, %[RMAXREC]    \n\t"
+            "fsw          f10,  (%[DST])          \n\t"
+            "addi         %[DST], %[DST], 4       \n\t"
+            "fdiv.s       f11, %[FONE], f10       \n\t"
+            "addi         t6,  %[CNT], 0          \n\t"
+            "LOOP_QUANT%=:                        \n\t"
+            "addi         t6, t6, -1              \n\t"
+            "vsetvli      t0, zero, e32, m8       \n\t"
+            "vle32.v      v0, (%[SRC])            \n\t"
+            "addi         %[SRC], %[SRC], 256     \n\t"
+            "vle32.v      v8, (%[SRC])            \n\t"
+            "addi         %[SRC], %[SRC], 256     \n\t"
+            "vle32.v      v16, (%[SRC])           \n\t"
+            "addi         %[SRC], %[SRC], 256     \n\t"
+            "vle32.v      v24, (%[SRC])           \n\t"
+            "addi         %[SRC], %[SRC], 256     \n\t"
+            "vsetvli      t0, zero, e32, m8       \n\t"
+            "vfmul.vf     v0, v0, f11             \n\t"
+            "vfmul.vf     v8, v8, f11             \n\t"
+            "vfmul.vf     v16, v16, f11           \n\t"
+            "vfmul.vf     v24, v24, f11           \n\t"
+            "vfcvt.x.f.v  v0, v0                  \n\t"
+            "vfcvt.x.f.v  v8, v8                  \n\t"
+            "vfcvt.x.f.v  v16, v16                \n\t"
+            "vfcvt.x.f.v  v24, v24                \n\t"
+            "vsetvli      t0, zero, e16, m4       \n\t"
+            "vnclip.wx    v0, v0, zero            \n\t"
+            "vnclip.wx    v4, v8, zero            \n\t"
+            "vnclip.wx    v8, v16, zero           \n\t"
+            "vnclip.wx    v12, v24, zero          \n\t"
+            "vsetvli      t0, zero, e8, m4        \n\t"
+            "vnclip.wx    v0, v0, zero            \n\t"
+            "vnclip.wx    v4, v8, zero            \n\t"
+            "vse8.v       v0, (%[DST])            \n\t"
+            "addi         %[DST], %[DST], 128     \n\t"
+            "vse8.v       v4, (%[DST])            \n\t"
+            "addi         %[DST], %[DST], 128     \n\t"
+            "bnez         t6, LOOP_QUANT%=        \n\t"
+            "sub           %[K], %[K], %[BLK]     \n\t"
+            "bge        %[K], %[BLK], LOOP_MAIN%= \n\t"
+            "blez         %[K], END%=             \n\t"
+            "LOOP_TAIL%=:                         \n\t"
+            "vsetvli      t0, zero, e32, m1       \n\t"
+            "vxor.vv      v31, v31, v31           \n\t"
+            "vse32.v      v31, (%[BUFFER])        \n\t"
+            "addi         t6, %[K], 0             \n\t"
+            "addi         s1, %[SRC], 0           \n\t"
+            "TAIL_CMP%=:                          \n\t"
+            "vsetvli      t0, zero, e32, m8       \n\t"
+            "vxor.vv       v0, v0, v0             \n\t"
+            "vsetvli      t0, t6, e32, m8         \n\t"
+            "vle32.v      v0, (%[SRC])            \n\t"
+            "addi         %[SRC], %[SRC], 256     \n\t"
+            "sub          t6, t6, t0              \n\t"
+            "vfabs.v      v0, v0                  \n\t"
+            "vsetvli      t0, zero, e32, m4       \n\t"
+            "vfmax.vv     v0, v0, v4              \n\t"
+            "vsetvli      t0, zero, e32, m2       \n\t"
+            "vfmax.vv     v0, v0, v2              \n\t"
+            "vsetvli      t0, zero, e32, m1       \n\t"
+            "vfmax.vv     v0, v0, v1              \n\t"
+            "vle32.v      v30, (%[BUFFER])        \n\t"
+            "vfmax.vv     v31, v30,  v0           \n\t"
+            "vse32.v      v31, (%[BUFFER])        \n\t"
+            "bnez         t6, TAIL_CMP%=          \n\t"
+            "addi         t6, %[K], 0             \n\t"
+            "flw          f0, (%[BUFFER])         \n\t"
+            "flw          f1, 4(%[BUFFER])        \n\t"
+            "flw          f2, 8(%[BUFFER])        \n\t"
+            "flw          f3, 12(%[BUFFER])       \n\t"
+            "flw          f4, 16(%[BUFFER])       \n\t"
+            "flw          f5, 20(%[BUFFER])       \n\t"
+            "flw          f6, 24(%[BUFFER])       \n\t"
+            "flw          f7, 28(%[BUFFER])       \n\t"
+            "fmax.s       f1, f0, f1              \n\t"
+            "fmax.s       f3, f2, f3              \n\t"
+            "fmax.s       f5, f4, f5              \n\t"
+            "fmax.s       f7, f6, f7              \n\t"
+            "fmax.s       f3, f1, f3              \n\t"
+            "fmax.s       f7, f5, f7              \n\t"
+            "fmax.s       f10, f3, f7             \n\t"
+            "fmul.s       f10, f10, %[RMAXREC]    \n\t"
+            "fsw          f10,  (%[DST])          \n\t"
+            "addi         %[DST], %[DST], 4       \n\t"
+            "fdiv.s       f11, %[FONE], f10       \n\t"
+            "addi         t6,  %[K], 0            \n\t"
+            "TAIL_QUANT%=:                        \n\t"
+            "vsetvli      t0, zero, e32, m8       \n\t"
+            "vxor.vv       v0, v0, v0             \n\t"
+            "vsetvli      t1, t6, e32, m8         \n\t"
+            "vle32.v      v0, (s1)                \n\t"
+            "addi         s1, s1, 256             \n\t"
+            "sub          t6, t6, t1              \n\t"
+            "vsetvli      t0, zero, e32, m8       \n\t"
+            "vfmul.vf     v0, v0, f11             \n\t"
+            "vfcvt.x.f.v  v0, v0                  \n\t"
+            "vsetvli      t0, zero, e16, m4       \n\t"
+            "vnclip.wx    v0, v0, zero            \n\t"
+            "vsetvli      t0, t1, e8, m2          \n\t"
+            "vnclip.wx    v0, v0, zero            \n\t"
+            "vse8.v       v0, (%[DST])            \n\t"
+            "addi         %[DST], %[DST], 64      \n\t"
+            "bnez         t6, TAIL_QUANT%=        \n\t"
+            "END%=:                               \n\t"
+            : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
+            : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer),
+              [CNT] "r"(cnt)
+            : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6");
+    }
+}
+
+}  // namespace ime1
+
+namespace {
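+// Inline-asm building blocks for the int8-activation x int4-weight GEMM kernels below:
+// vmadot dot-product accumulation, packed 4-bit weight loads (low/high nibble split),
+// zero-point gathers, and scale/bias application. The kernels stitch these fragments
+// together inside their __asm__ blocks.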
+#define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4          \
+    "vmadot       v16, v14, v0            \n\t" \
+    "vmadot       v18, v14, v1            \n\t" \
+    "vmadot       v20, v14, v2            \n\t" \
+    "vmadot       v22, v14, v3            \n\t" \
+    "vmadot       v16, v15, v4            \n\t" \
+    "vmadot       v18, v15, v5            \n\t" \
+    "vmadot       v20, v15, v6            \n\t" \
+    "vmadot       v22, v15, v7            \n\t"
+
+#define SQ4BIT_KERNEL_ACC_1X4X4                 \
+    "vfcvt.f.x.v  v16,  v16               \n\t" \
+    "vfcvt.f.x.v  v18,  v18               \n\t" \
+    "vfcvt.f.x.v  v20,  v20               \n\t" \
+    "vfcvt.f.x.v  v22,  v22               \n\t" \
+    "addi         s2, s1, 16              \n\t" \
+    "addi         s3, s1, 32              \n\t" \
+    "addi         s4, s1, 48              \n\t" \
+    "addi         s6, s5, 12              \n\t" \
+    "vfmacc.vv    v28, v16, v24           \n\t" \
+    "vfmacc.vv    v29, v18, v25           \n\t" \
+    "vfmacc.vv    v30, v20, v26           \n\t" \
+    "vfmacc.vv    v31, v22, v27           \n\t"
+
+#define SQ4BIT_KERNEL_ACC_F16_1X4X4             \
+    "vfcvt.f.x.v  v16,  v16               \n\t" \
+    "vfcvt.f.x.v  v18,  v18               \n\t" \
+    "vfcvt.f.x.v  v20,  v20               \n\t" \
+    "vfcvt.f.x.v  v22,  v22               \n\t" \
+    "addi         s2, s1, 8               \n\t" \
+    "addi         s3, s1, 16              \n\t" \
+    "addi         s4, s1, 24              \n\t" \
+    "addi         s6, s5, 12              \n\t" \
+    "vfmacc.vv    v28, v16, v24           \n\t" \
+    "vfmacc.vv    v29, v18, v25           \n\t" \
+    "vfmacc.vv    v30, v20, v26           \n\t" \
+    "vfmacc.vv    v31, v22, v27           \n\t"
+
+#define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4          \
+    "vle8.v       v4, (s1)                \n\t" \
+    "addi         s1, s1, 128             \n\t" \
+    "vle8.v       v5, (s2)                \n\t" \
+    "addi         s2, s2, 128             \n\t" \
+    "vle8.v       v6, (s3)                \n\t" \
+    "addi         s3, s3, 128             \n\t" \
+    "vle8.v       v7, (s4)                \n\t" \
+    "addi         s4, s4, 128             \n\t" \
+    "vsetvli      t0, zero, e8, mf4       \n\t" \
+    "vle8.v       v14, (s5)               \n\t" \
+    "addi         s5, s5, 16              \n\t" \
+    "vle8.v       v15, (s6)               \n\t" \
+    "addi         s6, s6, 16              \n\t" \
+    "addi         t5, t5, -1              \n\t" \
+    "vsetvli      t0, zero, e8, m1        \n\t" \
+    "vand.vi      v0, v4, 15              \n\t" \
+    "vand.vi      v1, v5, 15              \n\t" \
+    "vand.vi      v2, v6, 15              \n\t" \
+    "vand.vi      v3, v7, 15              \n\t" \
+    "vsrl.vi      v4, v4, 4               \n\t" \
+    "vsrl.vi      v5, v5, 4               \n\t" \
+    "vsrl.vi      v6, v6, 4               \n\t" \
+    "vsrl.vi      v7, v7, 4               \n\t"
+
+#define SQ4BIT_KERNEL_LOAD_ZP_16X1              \
+    "vsetvli      t0, zero, e8, mf2       \n\t" \
+    "vle8.v       v1, (s7)                \n\t" \
+    "vsetvli      t0, zero, e8, m1        \n\t" \
+    "vrgather.vv  v8, v1, v13             \n\t" \
+    "vadd.vi      v13, v13, 4             \n\t" \
+    "vrgather.vv  v9, v1, v13             \n\t" \
+    "vadd.vi      v13, v13, 4             \n\t" \
+    "vrgather.vv  v10, v1, v13            \n\t" \
+    "vadd.vi      v13, v13, 4             \n\t" \
+    "vrgather.vv  v11, v1, v13            \n\t" \
+    "vadd.vi      v13, v13, -12           \n\t"
+
+// used by the M4 kernels
+#define LOAD_B_16x8x2                           \
+    "vsetvli      t0, zero, e8, m1        \n\t" \
+    "vle8.v       v6, (s1)                \n\t" \
+    "addi         s1, s1, 32*4            \n\t" \
+    "vle8.v       v7, (s2)                \n\t" \
+    "addi         s2, s2, 32*4            \n\t" \
+    "vle8.v       v8, (s3)                \n\t" \
+    "addi         s3, s3, 32*4            \n\t" \
+    "vle8.v       v9, (s4)                \n\t" \
+    "addi         s4, s4, 32*4            \n\t" \
+                                                \
+    "vand.vi      v2, v6, 15              \n\t" \
+    "vand.vi      v3, v7, 15              \n\t" \
+    "vand.vi      v4, v8, 15              \n\t" \
+    "vand.vi      v5, v9, 15              \n\t" \
+                                                \
+    "vsrl.vi      v6, v6, 4               \n\t" \
+    "vsrl.vi      v7, v7, 4               \n\t" \
+    "vsrl.vi      v8, v8, 4               \n\t" \
+    "vsrl.vi      v9, v9, 4               \n\t"
+
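+// The LOAD_SCALE_* macros pre-multiply the per-block B scales by the four per-row A
+// scales (f1..f4); the 0xf0 lane mask in v0 lets each vector register carry the
+// products for two different A rows.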
+// [s2|s5, s3, s4, s6]
+#define LOAD_SCALE_4x16_FP16                    \
+    "addi         s2, s5, -8              \n\t" \
+    "addi         s3, s5, 8               \n\t" \
+    "addi         s4, s5, 16              \n\t" \
+    "addi         s6, s5, 24              \n\t" \
+    "li           t1, 0xf0                \n\t" \
+    "vmv.s.x      v0, t1                  \n\t" \
+    "vsetvli      t0, zero, e16, mf4      \n\t" \
+    "vle16.v      v9, (s5)                \n\t" \
+    "vle16.v      v11, (s3)               \n\t" \
+    "vle16.v      v13, (s4)               \n\t" \
+    "vle16.v      v15, (s6)               \n\t" \
+    "vsetvli      t0, zero, e16, mf2      \n\t" \
+    "vle16.v      v9, (s2), v0.t          \n\t" \
+    "vle16.v      v11, (s5), v0.t         \n\t" \
+    "vle16.v      v13, (s3), v0.t         \n\t" \
+    "vle16.v      v15, (s4), v0.t         \n\t" \
+    "vfwcvt.f.f.v v8, v9                  \n\t" \
+    "vfwcvt.f.f.v v10, v11                \n\t" \
+    "vfwcvt.f.f.v v12, v13                \n\t" \
+    "vfwcvt.f.f.v v14, v15                \n\t" \
+    "vsetvli      t0, zero, e32, m1       \n\t" \
+    "vmv.v.v      v9, v8                  \n\t" \
+    "vmv.v.v      v11, v10                \n\t" \
+    "vmv.v.v      v13, v12                \n\t" \
+    "vmv.v.v      v15, v14                \n\t" \
+    "li           t1, 0xf0                \n\t" \
+    "vmv.s.x      v0, t1                  \n\t" \
+    "vsetvli      t0, zero, e32, mf2      \n\t" \
+    "vfmul.vf     v8, v8, f1              \n\t" \
+    "vfmul.vf     v10, v10, f1            \n\t" \
+    "vfmul.vf     v12, v12, f1            \n\t" \
+    "vfmul.vf     v14, v14, f1            \n\t" \
+    "vfmul.vf     v9, v9, f3              \n\t" \
+    "vfmul.vf     v11, v11, f3            \n\t" \
+    "vfmul.vf     v13, v13, f3            \n\t" \
+    "vfmul.vf     v15, v15, f3            \n\t" \
+    "vsetvli      t0, zero, e32, m1       \n\t" \
+    "vfmul.vf     v8, v8, f2, v0.t        \n\t" \
+    "vfmul.vf     v10, v10, f2, v0.t      \n\t" \
+    "vfmul.vf     v12, v12, f2, v0.t      \n\t" \
+    "vfmul.vf     v14, v14, f2, v0.t      \n\t" \
+    "vfmul.vf     v9, v9, f4, v0.t        \n\t" \
+    "vfmul.vf     v11, v11, f4, v0.t      \n\t" \
+    "vfmul.vf     v13, v13, f4, v0.t      \n\t" \
+    "vfmul.vf     v15, v15, f4, v0.t      \n\t"
+
+// [s2|s5, s3, s4, s6]
+#define LOAD_SCALE_4x16                         \
+    "addi         s2, s5, -16             \n\t" \
+    "addi         s3, s5, 16              \n\t" \
+    "addi         s4, s5, 32              \n\t" \
+    "addi         s6, s5, 48              \n\t" \
+    "li           t1, 0xf0                \n\t" \
+    "vmv.s.x      v0, t1                  \n\t" \
+    "vsetvli      t0, zero, e32, mf2      \n\t" \
+    "vle32.v      v8, (s5)                \n\t" \
+    "vle32.v      v10, (s3)               \n\t" \
+    "vle32.v      v12, (s4)               \n\t" \
+    "vle32.v      v14, (s6)               \n\t" \
+    "vsetvli      t0, zero, e32, m1       \n\t" \
+    "vle32.v      v8, (s2), v0.t          \n\t" \
+    "vle32.v      v10, (s5), v0.t         \n\t" \
+    "vle32.v      v12, (s3), v0.t         \n\t" \
+    "vle32.v      v14, (s4), v0.t         \n\t" \
+    "vmv.v.v      v9, v8                  \n\t" \
+    "vmv.v.v      v11, v10                \n\t" \
+    "vmv.v.v      v13, v12                \n\t" \
+    "vmv.v.v      v15, v14                \n\t" \
+    "vsetvli      t0, zero, e32, mf2      \n\t" \
+    "vfmul.vf     v8, v8, f1              \n\t" \
+    "vfmul.vf     v10, v10, f1            \n\t" \
+    "vfmul.vf     v12, v12, f1            \n\t" \
+    "vfmul.vf     v14, v14, f1            \n\t" \
+    "vfmul.vf     v9, v9, f3              \n\t" \
+    "vfmul.vf     v11, v11, f3            \n\t" \
+    "vfmul.vf     v13, v13, f3            \n\t" \
+    "vfmul.vf     v15, v15, f3            \n\t" \
+    "vsetvli      t0, zero, e32, m1       \n\t" \
+    "vfmul.vf     v8, v8, f2, v0.t        \n\t" \
+    "vfmul.vf     v10, v10, f2, v0.t      \n\t" \
+    "vfmul.vf     v12, v12, f2, v0.t      \n\t" \
+    "vfmul.vf     v14, v14, f2, v0.t      \n\t" \
+    "vfmul.vf     v9, v9, f4, v0.t        \n\t" \
+    "vfmul.vf     v11, v11, f4, v0.t      \n\t" \
+    "vfmul.vf     v13, v13, f4, v0.t      \n\t" \
+    "vfmul.vf     v15, v15, f4, v0.t      \n\t"
+
+// [s1|BIAS, s2, s3, s4]
+#define LOAD_BIAS                               \
+    "vsetvli      t0, zero, e32, mf2      \n\t" \
+    "li           t1, 0xf0                \n\t" \
+    "vmv.s.x      v0, t1                  \n\t" \
+    "addi         s1, %[BIAS], -16        \n\t" \
+    "addi         s2, %[BIAS], 16         \n\t" \
+    "addi         s3, %[BIAS], 32         \n\t" \
+    "addi         s4, %[BIAS], 48         \n\t" \
+                                                \
+    "vle32.v      v24, (%[BIAS])          \n\t" \
+    "vle32.v      v26, (s2)               \n\t" \
+    "vle32.v      v28, (s3)               \n\t" \
+    "vle32.v      v30, (s4)               \n\t" \
+    "vsetvli      t0, zero, e32, m1       \n\t" \
+    "vle32.v      v24, (s1), v0.t         \n\t" \
+    "vle32.v      v26, (%[BIAS]), v0.t    \n\t" \
+    "vle32.v      v28, (s2), v0.t         \n\t" \
+    "vle32.v      v30, (s3), v0.t         \n\t" \
+    "vmv.v.v      v25, v24                \n\t" \
+    "vmv.v.v      v27, v26                \n\t" \
+    "vmv.v.v      v29, v28                \n\t" \
+    "vmv.v.v      v31, v30                \n\t"
+
+#define SQ4BIT_KERNEL_COMP_4x16x16              \
+    "vmadot       v16, v10, v2            \n\t" \
+    "vmadot       v18, v10, v3            \n\t" \
+    "vmadot       v20, v10, v4            \n\t" \
+    "vmadot       v22, v10, v5            \n\t" \
+    "vmadot       v16, v11, v6            \n\t" \
+    "vmadot       v18, v11, v7            \n\t" \
+    "vmadot       v20, v11, v8            \n\t" \
+    "vmadot       v22, v11, v9            \n\t"
+
+#define SAVE_RESULT_4x16                        \
+    "addi         a1, %[C], 0             \n\t" \
+    "add          a2, %[C], %[LDC]        \n\t" \
+    "add          a3, a2, %[LDC]          \n\t" \
+    "add          a4, a3, %[LDC]          \n\t" \
+    "addi         a2, a2, -16             \n\t" \
+    "addi         a4, a4, -16             \n\t" \
+    "li           t1, 0xf0                \n\t" \
+    "vmv.s.x      v0, t1                  \n\t" \
+    "vsetvli      t0, zero, e32, mf2      \n\t" \
+                                                \
+    "vse32.v      v24, (a1)               \n\t" \
+    "addi         a1, a1, 16              \n\t" \
+    "vse32.v      v25, (a3)               \n\t" \
+    "addi         a3, a3, 16              \n\t" \
+                                                \
+    "vse32.v      v26, (a1)               \n\t" \
+    "addi         a1, a1, 16              \n\t" \
+    "vse32.v      v27, (a3)               \n\t" \
+    "addi         a3, a3, 16              \n\t" \
+                                                \
+    "vse32.v      v28, (a1)               \n\t" \
+    "addi         a1, a1, 16              \n\t" \
+    "vse32.v      v29, (a3)               \n\t" \
+    "addi         a3, a3, 16              \n\t" \
+                                                \
+    "vse32.v      v30, (a1)               \n\t" \
+    "vse32.v      v31, (a3)               \n\t" \
+    "vsetvli      t0, zero, e32, m1       \n\t" \
+                                                \
+    "vse32.v      v24, (a2), v0.t         \n\t" \
+    "addi         a2, a2, 16              \n\t" \
+    "vse32.v      v25, (a4), v0.t         \n\t" \
+    "addi         a4, a4, 16              \n\t" \
+                                                \
+    "vse32.v      v26, (a2), v0.t         \n\t" \
+    "addi         a2, a2, 16              \n\t" \
+    "vse32.v      v27, (a4), v0.t         \n\t" \
+    "addi         a4, a4, 16              \n\t" \
+                                                \
+    "vse32.v      v28, (a2), v0.t         \n\t" \
+    "addi         a2, a2, 16              \n\t" \
+    "vse32.v      v29, (a4), v0.t         \n\t" \
+    "addi         a4, a4, 16              \n\t" \
+                                                \
+    "vse32.v      v30, (a2), v0.t         \n\t" \
+    "vse32.v      v31, (a4), v0.t         \n\t"
+
+#define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2           \
+    "vsetvli      t0, zero, e8, mf2       \n\t" \
+    "vle8.v       v11, (s6)               \n\t" \
+    "vsetvli      t0, zero, e8, m1        \n\t" \
+    "vrgather.vv  v12, v11, v1            \n\t" \
+    "vadd.vi      v1, v1, 4               \n\t" \
+    "vrgather.vv  v13, v11, v1            \n\t" \
+    "vadd.vi      v1, v1, 4               \n\t" \
+    "vrgather.vv  v14, v11, v1            \n\t" \
+    "vadd.vi      v1, v1, 4               \n\t" \
+    "vrgather.vv  v15, v11, v1            \n\t" \
+    "vadd.vi      v1, v1, -12             \n\t"
+
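+// 4-row GEMM kernel for 4-bit B with fp16 per-block scales and optional per-block
+// zero points. Partial column tiles (NBLKS < 16) accumulate into the on-stack `tmp`
+// buffer and are copied out to C after the main loop.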
+template <bool HasZeroPoint>
+void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t            BlkLen,
+                                                const std::byte * QuantA,
+                                                const std::byte * QuantBData,
+                                                const float *     QuantBScale,
+                                                const std::byte * QuantBZeroPoint,
+                                                float *           C,
+                                                size_t            CountN,
+                                                size_t            BlockCountK,
+                                                const float *     Bias,
+                                                const size_t      ldc) {
+    GGML_UNUSED(QuantBScale);
+    GGML_UNUSED(QuantBZeroPoint);
+    size_t       LDC   = ldc * sizeof(float);
+    const size_t INNER = BlkLen / 16;
+    float        tmp[4 * 16];
+
+    if constexpr (HasZeroPoint) {
+        for (size_t n = 0; n < CountN; n += 16) {
+            size_t      NBLKS         = (CountN - n) > 16 ? 16 : CountN - n;
+            std::byte * QuantBDataPtr = (std::byte *) QuantBData +           //
+                                        n * BlockCountK * BlkLen / 2 +       // b data
+                                        n * BlockCountK * sizeof(uint8_t) +  // zp
+                                        n * BlockCountK * sizeof(_Float16);    // scale
+            float * CPtr = C + n;
+            if (NBLKS < 16) {
+                CPtr = tmp;
+                LDC  = 16 * sizeof(float);
+            }
+            if (Bias != nullptr) {
+                const float * bias = Bias + n;
+                if (NBLKS < 16) {
+                    __asm__ volatile(
+                        "vsetvli        t0, %[N], e32, m2     \n\t"
+                        "vle32.v        v0, (%[SRC])          \n\t"
+                        "vse32.v        v0, (%[DST])          \n\t"
+                        :
+                        : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+                        : "cc", "t0");
+                    bias = tmp;
+                }
+                __asm__ volatile(LOAD_BIAS
+
+                                 "addi               t3, %[BlockCountK], 0       \n\t"
+
+                                 "vsetvli            t0, zero, e8, m1            \n\t"
+                                 "li                 s1, 24                      \n\t"
+                                 "vmv.v.i            v1, 3                       \n\t"
+                                 "vsetvli            t0, s1, e8, m1              \n\t"
+                                 "vmv.v.i            v1, 2                       \n\t"
+                                 "vsetvli            t0, zero, e8, mf2           \n\t"
+                                 "vmv.v.i            v1, 1                       \n\t"
+                                 "vsetvli            t0, zero, e8, mf4           \n\t"
+                                 "vmv.v.i            v1, 0                       \n\t"
+
+                                 "addi               a1, %[A], 0                 \n\t"
+                                 "addi               s1, %[B], 0                 \n\t"
+
+                                 "BLOCK_COUNTK_LOOP%=:                           \n\t"
+                                 // scale offset
+                                 "addi               s5, s1, 0                   \n\t"
+                                 // zp offset
+                                 "addi               s6, s1, 32                  \n\t"
+                                 "addi               s1, s6, 16                  \n\t"
+                                 "addi               s2, s1, 32                  \n\t"
+                                 "addi               s3, s1, 32*2                \n\t"
+                                 "addi               s4, s1, 32*3                \n\t"
+
+                                 "vsetvli            t0, zero, e32, m8           \n\t"
+                                 "vxor.vv            v16, v16, v16               \n\t"
+                                 // load a scale
+                                 "flw                f1, (a1)                    \n\t"
+                                 "flw                f2, 4(a1)                   \n\t"
+                                 "flw                f3, 8(a1)                   \n\t"
+                                 "flw                f4, 12(a1)                  \n\t"
+                                 "addi               a1, a1, 16                  \n\t"
+                                 "addi               t2, %[INNER], 0             \n\t"
+
+                                 SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+                                 "BLOCK_INNER_LOOP%=:                            \n\t"
+
+                                 LOAD_B_16x8x2
+
+                                 "vle8.v             v10, (a1)                   \n\t"
+                                 "addi               a1, a1, 32                  \n\t"
+                                 "vle8.v             v11, (a1)                   \n\t"
+                                 "addi               a1, a1, 32                  \n\t"
+                                 "vsub.vv            v2, v2, v12                 \n\t"
+                                 "vsub.vv            v6, v6, v12                 \n\t"
+                                 "vsub.vv            v3, v3, v13                 \n\t"
+                                 "vsub.vv            v7, v7, v13                 \n\t"
+                                 "vsub.vv            v4, v4, v14                 \n\t"
+                                 "vsub.vv            v8, v8, v14                 \n\t"
+                                 "vsub.vv            v5, v5, v15                 \n\t"
+                                 "vsub.vv            v9, v9, v15                 \n\t"
+
+                                 SQ4BIT_KERNEL_COMP_4x16x16
+
+                                 "addi               t2, t2, -1                  \n\t"
+                                 "bnez               t2, BLOCK_INNER_LOOP%=      \n\t"
+
+                                 LOAD_SCALE_4x16_FP16
+
+                                 "vsetvli            t0, zero, e32, m8           \n\t"
+                                 "vfcvt.f.x.v        v16, v16                    \n\t"
+                                 "vfmacc.vv          v24, v16, v8                \n\t"
+                                 "addi               t3, t3, -1                  \n\t"
+                                 "bnez               t3, BLOCK_COUNTK_LOOP%=     \n\t"
+
+                                 "RESULT_SAVE%=:                                 \n\t"
+
+                                 SAVE_RESULT_4x16
+
+                                 :
+                                 : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+                                   [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+                                 : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+                                   "s2", "s3", "s4", "s5", "s6");
+
+            } else {
+                __asm__ volatile(
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vxor.vv            v24, v24, v24               \n\t"
+                    "addi               t3, %[BlockCountK], 0       \n\t"
+                    "vsetvli            t0, zero, e8, m1            \n\t"
+                    "li                 s1, 24                      \n\t"
+                    "vmv.v.i            v1, 3                       \n\t"
+                    "vsetvli            t0, s1, e8, m1              \n\t"
+                    "vmv.v.i            v1, 2                       \n\t"
+                    "vsetvli            t0, zero, e8, mf2           \n\t"
+                    "vmv.v.i            v1, 1                       \n\t"
+                    "vsetvli            t0, zero, e8, mf4           \n\t"
+                    "vmv.v.i            v1, 0                       \n\t"
+                    "addi               a1, %[A], 0                 \n\t"
+                    "addi               s1, %[B], 0                 \n\t"
+                    "BLOCK_COUNTK_LOOP%=:                           \n\t"
+                    // scale offset
+                    "addi               s5, s1, 0                   \n\t"
+                    // zp offset
+                    "addi               s6, s1, 32                  \n\t"
+                    "addi               s1, s6, 16                  \n\t"
+                    "addi               s2, s1, 32                  \n\t"
+                    "addi               s3, s1, 32*2                \n\t"
+                    "addi               s4, s1, 32*3                \n\t"
+
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vxor.vv            v16, v16, v16               \n\t"
+                    // load a scale
+                    "flw                f1, (a1)                    \n\t"
+                    "flw                f2, 4(a1)                   \n\t"
+                    "flw                f3, 8(a1)                   \n\t"
+                    "flw                f4, 12(a1)                  \n\t"
+                    "addi               a1, a1, 16                  \n\t"
+                    "addi               t2, %[INNER], 0             \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+                    "BLOCK_INNER_LOOP%=:                            \n\t"
+
+                    LOAD_B_16x8x2
+
+                    "vle8.v             v10, (a1)                   \n\t"
+                    "addi               a1, a1, 32                  \n\t"
+                    "vle8.v             v11, (a1)                   \n\t"
+                    "addi               a1, a1, 32                  \n\t"
+                    "vsub.vv            v2, v2, v12                 \n\t"
+                    "vsub.vv            v6, v6, v12                 \n\t"
+                    "vsub.vv            v3, v3, v13                 \n\t"
+                    "vsub.vv            v7, v7, v13                 \n\t"
+                    "vsub.vv            v4, v4, v14                 \n\t"
+                    "vsub.vv            v8, v8, v14                 \n\t"
+                    "vsub.vv            v5, v5, v15                 \n\t"
+                    "vsub.vv            v9, v9, v15                 \n\t"
+
+                    SQ4BIT_KERNEL_COMP_4x16x16
+
+                    "addi               t2, t2, -1                  \n\t"
+                    "bnez               t2, BLOCK_INNER_LOOP%=      \n\t"
+
+                    LOAD_SCALE_4x16_FP16
+
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vfcvt.f.x.v        v16, v16                    \n\t"
+                    "vfmacc.vv          v24, v16, v8                \n\t"
+                    "addi               t3, t3, -1                  \n\t"
+                    "bnez               t3, BLOCK_COUNTK_LOOP%=     \n\t"
+
+                    "RESULT_SAVE%=:                                 \n\t"
+
+                    SAVE_RESULT_4x16
+
+                    :
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+                      [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+                    : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+                      "s4", "s5", "s6");
+            }
+        }
+    } else {
+        for (size_t n = 0; n < CountN; n += 16) {
+            size_t      NBLKS         = (CountN - n) > 16 ? 16 : CountN - n;
+            std::byte * QuantBDataPtr = (std::byte *) QuantBData +         //
+                                        n * BlockCountK * BlkLen / 2 +     // b data
+                                        n * BlockCountK * sizeof(_Float16);  // scale
+            float * CPtr = C + n;
+            if (NBLKS < 16) {
+                CPtr = tmp;
+                LDC  = 16 * sizeof(float);
+            }
+            if (Bias != nullptr) {
+                const float * bias = Bias + n;
+                if (NBLKS < 16) {
+                    __asm__ volatile(
+                        "vsetvli        t0, %[N], e32, m2     \n\t"
+                        "vle32.v        v0, (%[SRC])          \n\t"
+                        "vse32.v        v0, (%[DST])          \n\t"
+                        :
+                        : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+                        : "cc", "t0");
+                    bias = tmp;
+                }
+                __asm__ volatile(LOAD_BIAS
+
+                                 "addi               t3, %[BlockCountK], 0       \n\t"
+                                 "addi               a1, %[A], 0                 \n\t"
+                                 "addi               s1, %[B], 0                 \n\t"
+                                 "BLOCK_COUNTK_LOOP%=:                           \n\t"
+                                 "addi               s5, s1, 0                   \n\t"
+                                 "addi               s1, s5, 32                  \n\t"
+                                 "addi               s2, s1, 32                  \n\t"
+                                 "addi               s3, s1, 32*2                \n\t"
+                                 "addi               s4, s1, 32*3                \n\t"
+                                 "vsetvli            t0, zero, e32, m8           \n\t"
+                                 "vxor.vv            v16, v16, v16               \n\t"
+                                 // load a scale
+                                 "flw                f1, (a1)                    \n\t"
+                                 "flw                f2, 4(a1)                   \n\t"
+                                 "flw                f3, 8(a1)                   \n\t"
+                                 "flw                f4, 12(a1)                  \n\t"
+                                 "addi               a1, a1, 16                  \n\t"
+                                 "addi               t2, %[INNER], 0             \n\t"
+                                 "BLOCK_INNER_LOOP%=:                            \n\t"
+
+                                 LOAD_B_16x8x2
+
+                                 "vsetvli            t0, zero, e8, m1            \n\t"
+                                 "vle8.v             v10, (a1)                   \n\t"
+                                 "addi               a1, a1, 32                  \n\t"
+                                 "vle8.v             v11, (a1)                   \n\t"
+                                 "addi               a1, a1, 32                  \n\t"
+                                 "vadd.vi            v2, v2, -8                  \n\t"
+                                 "vadd.vi            v3, v3, -8                  \n\t"
+                                 "vadd.vi            v4, v4, -8                  \n\t"
+                                 "vadd.vi            v5, v5, -8                  \n\t"
+                                 "vadd.vi            v6, v6, -8                  \n\t"
+                                 "vadd.vi            v7, v7, -8                  \n\t"
+                                 "vadd.vi            v8, v8, -8                  \n\t"
+                                 "vadd.vi            v9, v9, -8                  \n\t"
+
+                                 SQ4BIT_KERNEL_COMP_4x16x16
+
+                                 "addi               t2, t2, -1                  \n\t"
+                                 "bnez               t2, BLOCK_INNER_LOOP%=      \n\t"
+
+                                 LOAD_SCALE_4x16_FP16
+
+                                 "vsetvli            t0, zero, e32, m8           \n\t"
+                                 "vfcvt.f.x.v        v16, v16                    \n\t"
+                                 "vfmacc.vv          v24, v16, v8                \n\t"
+                                 "addi               t3, t3, -1                  \n\t"
+                                 "bnez               t3, BLOCK_COUNTK_LOOP%=     \n\t"
+                                 "RESULT_SAVE%=:                                 \n\t"
+
+                                 SAVE_RESULT_4x16
+
+                                 :
+                                 : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+                                   [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+                                 : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+                                   "s2", "s3", "s4", "s5", "s6");
+
+            } else {
+                __asm__ volatile(
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vxor.vv            v24, v24, v24               \n\t"
+                    "addi               t3, %[BlockCountK], 0       \n\t"
+                    "addi               a1, %[A], 0                 \n\t"
+                    "addi               s1, %[B], 0                 \n\t"
+                    "BLOCK_COUNTK_LOOP%=:                           \n\t"
+                    "addi               s5, s1, 0                   \n\t"
+                    "addi               s1, s5, 32                  \n\t"
+                    "addi               s2, s1, 32                  \n\t"
+                    "addi               s3, s1, 32*2                \n\t"
+                    "addi               s4, s1, 32*3                \n\t"
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vxor.vv            v16, v16, v16               \n\t"
+                    // load a scale
+                    "flw                f1, (a1)                    \n\t"
+                    "flw                f2, 4(a1)                   \n\t"
+                    "flw                f3, 8(a1)                   \n\t"
+                    "flw                f4, 12(a1)                  \n\t"
+                    "addi               a1, a1, 16                  \n\t"
+                    "addi               t2, %[INNER], 0             \n\t"
+                    "BLOCK_INNER_LOOP%=:                            \n\t"
+
+                    LOAD_B_16x8x2
+
+                    "vsetvli            t0, zero, e8, m1            \n\t"
+                    "vle8.v             v10, (a1)                   \n\t"
+                    "addi               a1, a1, 32                  \n\t"
+                    "vle8.v             v11, (a1)                   \n\t"
+                    "addi               a1, a1, 32                  \n\t"
+                    "vadd.vi            v2, v2, -8                  \n\t"
+                    "vadd.vi            v3, v3, -8                  \n\t"
+                    "vadd.vi            v4, v4, -8                  \n\t"
+                    "vadd.vi            v5, v5, -8                  \n\t"
+                    "vadd.vi            v6, v6, -8                  \n\t"
+                    "vadd.vi            v7, v7, -8                  \n\t"
+                    "vadd.vi            v8, v8, -8                  \n\t"
+                    "vadd.vi            v9, v9, -8                  \n\t"
+
+                    SQ4BIT_KERNEL_COMP_4x16x16
+
+                    "addi               t2, t2, -1                  \n\t"
+                    "bnez               t2, BLOCK_INNER_LOOP%=      \n\t"
+
+                    LOAD_SCALE_4x16_FP16
+
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vfcvt.f.x.v        v16, v16                    \n\t"
+                    "vfmacc.vv          v24, v16, v8                \n\t"
+                    "addi               t3, t3, -1                  \n\t"
+                    "bnez               t3, BLOCK_COUNTK_LOOP%=     \n\t"
+                    "RESULT_SAVE%=:                                 \n\t"
+
+                    SAVE_RESULT_4x16
+
+                    :
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+                      [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+                    : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+                      "s4", "s5", "s6");
+            }
+        }
+    }
+    if (CountN % 16 != 0) {
+        // store output from tmp to C when NBLKS is less than 16.
+        float *      CPtr = C + CountN / 16 * 16;
+        const size_t N    = CountN % 16;
+        LDC               = ldc * sizeof(float);
+        __asm__ volatile(
+            "vsetvli            t0, %[N], e32, m2       \n\t"
+            "vle32.v            v0, (%[SRC])            \n\t"
+            "addi               s2, %[SRC], 64          \n\t"
+            "addi               s3, %[SRC], 64*2        \n\t"
+            "addi               s4, %[SRC], 64*3        \n\t"
+            "vle32.v            v2, (s2)                \n\t"
+            "vle32.v            v4, (s3)                \n\t"
+            "vle32.v            v6, (s4)                \n\t"
+            "add                t2, %[DST], %[LDC]      \n\t"
+            "add                t3, t2, %[LDC]          \n\t"
+            "add                t4, t3, %[LDC]          \n\t"
+            "vse32.v            v0, (%[DST])            \n\t"
+            "vse32.v            v2, (t2)                \n\t"
+            "vse32.v            v4, (t3)                \n\t"
+            "vse32.v            v6, (t4)                \n\t"
+            :
+            : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
+            : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
+    }
+}
+
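+// Same 4-row kernel, but for packed B blocks carrying fp32 scales (note the
+// sizeof(float) stride and the LOAD_SCALE_4x16 macro instead of the fp16 variant).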
+template <bool HasZeroPoint>
+void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t            BlkLen,
+                                      const std::byte * QuantA,
+                                      const std::byte * QuantBData,
+                                      const float *     QuantBScale,
+                                      const std::byte * QuantBZeroPoint,
+                                      float *           C,
+                                      size_t            CountN,
+                                      size_t            BlockCountK,
+                                      const float *     Bias,
+                                      const size_t      ldc) {
+    GGML_UNUSED(QuantBScale);
+    GGML_UNUSED(QuantBZeroPoint);
+    size_t       LDC   = ldc * sizeof(float);
+    const size_t INNER = BlkLen / 16;
+    float        tmp[4 * 16];
+
+    if constexpr (HasZeroPoint) {
+        for (size_t n = 0; n < CountN; n += 16) {
+            size_t      NBLKS         = (CountN - n) > 16 ? 16 : CountN - n;
+            std::byte * QuantBDataPtr = (std::byte *) QuantBData +           //
+                                        n * BlockCountK * BlkLen / 2 +       // b data
+                                        n * BlockCountK * sizeof(uint8_t) +  // zp
+                                        n * BlockCountK * sizeof(float);     // scale
+            float * CPtr = C + n;
+            if (NBLKS < 16) {
+                CPtr = tmp;
+                LDC  = 16 * sizeof(float);
+            }
+            if (Bias != nullptr) {
+                const float * bias = Bias + n;
+                if (NBLKS < 16) {
+                    __asm__ volatile(
+                        "vsetvli        t0, %[N], e32, m2     \n\t"
+                        "vle32.v        v0, (%[SRC])          \n\t"
+                        "vse32.v        v0, (%[DST])          \n\t"
+                        :
+                        : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+                        : "cc", "t0");
+                    bias = tmp;
+                }
+
+                __asm__ volatile(LOAD_BIAS
+                                 "addi               t3, %[BlockCountK], 0       \n\t"
+                                 "vsetvli            t0, zero, e8, m1            \n\t"
+                                 "li                 s1, 24                      \n\t"
+                                 "vmv.v.i            v1, 3                       \n\t"
+                                 "vsetvli            t0, s1, e8, m1              \n\t"
+                                 "vmv.v.i            v1, 2                       \n\t"
+                                 "vsetvli            t0, zero, e8, mf2           \n\t"
+                                 "vmv.v.i            v1, 1                       \n\t"
+                                 "vsetvli            t0, zero, e8, mf4           \n\t"
+                                 "vmv.v.i            v1, 0                       \n\t"
+                                 "addi               a1, %[A], 0                 \n\t"
+                                 "addi               s1, %[B], 0                 \n\t"
+                                 "BLOCK_COUNTK_LOOP%=:                           \n\t"
+                                 // scale offset
+                                 "addi               s5, s1, 0                   \n\t"
+                                 // zp offset
+                                 "addi               s6, s1, 64                  \n\t"
+                                 "addi               s1, s6, 16                  \n\t"
+                                 "addi               s2, s1, 32                  \n\t"
+                                 "addi               s3, s1, 32*2                \n\t"
+                                 "addi               s4, s1, 32*3                \n\t"
+                                 "vsetvli            t0, zero, e32, m8           \n\t"
+                                 "vxor.vv            v16, v16, v16               \n\t"
+                                 // load a scale
+                                 "flw                f1, (a1)                    \n\t"
+                                 "flw                f2, 4(a1)                   \n\t"
+                                 "flw                f3, 8(a1)                   \n\t"
+                                 "flw                f4, 12(a1)                  \n\t"
+                                 "addi               a1, a1, 16                  \n\t"
+                                 "addi               t2, %[INNER], 0             \n\t"
+
+                                 SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+                                 "BLOCK_INNER_LOOP%=:                            \n\t"
+
+                                 LOAD_B_16x8x2
+
+                                 "vle8.v             v10, (a1)                   \n\t"
+                                 "addi               a1, a1, 32                  \n\t"
+                                 "vle8.v             v11, (a1)                   \n\t"
+                                 "addi               a1, a1, 32                  \n\t"
+                                 "vsub.vv            v2, v2, v12                 \n\t"
+                                 "vsub.vv            v6, v6, v12                 \n\t"
+                                 "vsub.vv            v3, v3, v13                 \n\t"
+                                 "vsub.vv            v7, v7, v13                 \n\t"
+                                 "vsub.vv            v4, v4, v14                 \n\t"
+                                 "vsub.vv            v8, v8, v14                 \n\t"
+                                 "vsub.vv            v5, v5, v15                 \n\t"
+                                 "vsub.vv            v9, v9, v15                 \n\t"
+
+                                 SQ4BIT_KERNEL_COMP_4x16x16
+
+                                 "addi               t2, t2, -1                  \n\t"
+                                 "bnez               t2, BLOCK_INNER_LOOP%=      \n\t"
+
+                                 LOAD_SCALE_4x16
+
+                                 "vsetvli            t0, zero, e32, m8           \n\t"
+                                 "vfcvt.f.x.v        v16, v16                    \n\t"
+                                 "vfmacc.vv          v24, v16, v8                \n\t"
+                                 "addi               t3, t3, -1                  \n\t"
+                                 "bnez               t3, BLOCK_COUNTK_LOOP%=     \n\t"
+
+                                 "RESULT_SAVE%=:                                 \n\t"
+
+                                 SAVE_RESULT_4x16
+
+                                 :
+                                 : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+                                   [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+                                 : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+                                   "s2", "s3", "s4", "s5", "s6");
+
+            } else {
+                __asm__ volatile(
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vxor.vv            v24, v24, v24               \n\t"
+                    "addi               t3, %[BlockCountK], 0       \n\t"
+                    "vsetvli            t0, zero, e8, m1            \n\t"
+                    "li                 s1, 24                      \n\t"
+                    "vmv.v.i            v1, 3                       \n\t"
+                    "vsetvli            t0, s1, e8, m1              \n\t"
+                    "vmv.v.i            v1, 2                       \n\t"
+                    "vsetvli            t0, zero, e8, mf2           \n\t"
+                    "vmv.v.i            v1, 1                       \n\t"
+                    "vsetvli            t0, zero, e8, mf4           \n\t"
+                    "vmv.v.i            v1, 0                       \n\t"
+                    "addi               a1, %[A], 0                 \n\t"
+                    "addi               s1, %[B], 0                 \n\t"
+                    "BLOCK_COUNTK_LOOP%=:                           \n\t"
+                    // scale offset
+                    "addi               s5, s1, 0                   \n\t"
+                    // zp offset
+                    "addi               s6, s1, 64                  \n\t"
+                    "addi               s1, s6, 16                  \n\t"
+                    "addi               s2, s1, 32                  \n\t"
+                    "addi               s3, s1, 32*2                \n\t"
+                    "addi               s4, s1, 32*3                \n\t"
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vxor.vv            v16, v16, v16               \n\t"
+                    // load a scale
+                    "flw                f1, (a1)                    \n\t"
+                    "flw                f2, 4(a1)                   \n\t"
+                    "flw                f3, 8(a1)                   \n\t"
+                    "flw                f4, 12(a1)                  \n\t"
+                    "addi               a1, a1, 16                  \n\t"
+                    "addi               t2, %[INNER], 0             \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+                    "BLOCK_INNER_LOOP%=:                            \n\t"
+
+                    LOAD_B_16x8x2
+
+                    "vle8.v             v10, (a1)                   \n\t"
+                    "addi               a1, a1, 32                  \n\t"
+                    "vle8.v             v11, (a1)                   \n\t"
+                    "addi               a1, a1, 32                  \n\t"
+                    "vsub.vv            v2, v2, v12                 \n\t"
+                    "vsub.vv            v6, v6, v12                 \n\t"
+                    "vsub.vv            v3, v3, v13                 \n\t"
+                    "vsub.vv            v7, v7, v13                 \n\t"
+                    "vsub.vv            v4, v4, v14                 \n\t"
+                    "vsub.vv            v8, v8, v14                 \n\t"
+                    "vsub.vv            v5, v5, v15                 \n\t"
+                    "vsub.vv            v9, v9, v15                 \n\t"
+
+                    SQ4BIT_KERNEL_COMP_4x16x16
+
+                    "addi               t2, t2, -1                  \n\t"
+                    "bnez               t2, BLOCK_INNER_LOOP%=      \n\t"
+
+                    LOAD_SCALE_4x16
+
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vfcvt.f.x.v        v16, v16                    \n\t"
+                    "vfmacc.vv          v24, v16, v8                \n\t"
+                    "addi               t3, t3, -1                  \n\t"
+                    "bnez               t3, BLOCK_COUNTK_LOOP%=     \n\t"
+
+                    "RESULT_SAVE%=:                                 \n\t"
+
+                    SAVE_RESULT_4x16
+
+                    :
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+                      [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+                    : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+                      "s4", "s5", "s6");
+            }
+        }
+    } else {
+        for (size_t n = 0; n < CountN; n += 16) {
+            size_t      NBLKS         = (CountN - n) > 16 ? 16 : CountN - n;
+            std::byte * QuantBDataPtr = (std::byte *) QuantBData +        //
+                                        n * BlockCountK * BlkLen / 2 +    // b data
+                                        n * BlockCountK * sizeof(float);  // scale
+            float * CPtr = C + n;
+            if (NBLKS < 16) {
+                CPtr = tmp;
+                LDC  = 16 * sizeof(float);
+            }
+            if (Bias != nullptr) {
+                const float * bias = Bias + n;
+                if (NBLKS < 16) {
+                    __asm__ volatile(
+                        "vsetvli        t0, %[N], e32, m2     \n\t"
+                        "vle32.v        v0, (%[SRC])          \n\t"
+                        "vse32.v        v0, (%[DST])          \n\t"
+                        :
+                        : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+                        : "cc", "t0");
+                    bias = tmp;
+                }
+                __asm__ volatile(LOAD_BIAS
+                                 "addi               t3, %[BlockCountK], 0       \n\t"
+                                 "addi               a1, %[A], 0                 \n\t"
+                                 "addi               s1, %[B], 0                 \n\t"
+                                 "BLOCK_COUNTK_LOOP%=:                           \n\t"
+                                 "addi               s5, s1, 0                   \n\t"
+                                 "addi               s1, s5, 64                  \n\t"
+                                 "addi               s2, s1, 32                  \n\t"
+                                 "addi               s3, s1, 32*2                \n\t"
+                                 "addi               s4, s1, 32*3                \n\t"
+                                 "vsetvli            t0, zero, e32, m8           \n\t"
+                                 "vxor.vv            v16, v16, v16               \n\t"
+                                 // load a scale
+                                 "flw                f1, (a1)                    \n\t"
+                                 "flw                f2, 4(a1)                   \n\t"
+                                 "flw                f3, 8(a1)                   \n\t"
+                                 "flw                f4, 12(a1)                  \n\t"
+                                 "addi               a1, a1, 16                  \n\t"
+                                 "addi               t2, %[INNER], 0             \n\t"
+                                 "BLOCK_INNER_LOOP%=:                            \n\t"
+
+                                 LOAD_B_16x8x2
+
+                                 "vsetvli            t0, zero, e8, m1            \n\t"
+                                 "vle8.v             v10, (a1)                   \n\t"
+                                 "addi               a1, a1, 32                  \n\t"
+                                 "vle8.v             v11, (a1)                   \n\t"
+                                 "addi               a1, a1, 32                  \n\t"
+                                 "vadd.vi            v2, v2, -8                  \n\t"
+                                 "vadd.vi            v3, v3, -8                  \n\t"
+                                 "vadd.vi            v4, v4, -8                  \n\t"
+                                 "vadd.vi            v5, v5, -8                  \n\t"
+                                 "vadd.vi            v6, v6, -8                  \n\t"
+                                 "vadd.vi            v7, v7, -8                  \n\t"
+                                 "vadd.vi            v8, v8, -8                  \n\t"
+                                 "vadd.vi            v9, v9, -8                  \n\t"
+
+                                 SQ4BIT_KERNEL_COMP_4x16x16
+
+                                 "addi               t2, t2, -1                  \n\t"
+                                 "bnez               t2, BLOCK_INNER_LOOP%=      \n\t"
+
+                                 LOAD_SCALE_4x16
+
+                                 "vsetvli            t0, zero, e32, m8           \n\t"
+                                 "vfcvt.f.x.v        v16, v16                    \n\t"
+                                 "vfmacc.vv          v24, v16, v8                \n\t"
+                                 "addi               t3, t3, -1                  \n\t"
+                                 "bnez               t3, BLOCK_COUNTK_LOOP%=     \n\t"
+
+                                 "RESULT_SAVE%=:                                 \n\t"
+
+                                 SAVE_RESULT_4x16
+
+                                 :
+                                 : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+                                   [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+                                 : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+                                   "s2", "s3", "s4", "s5", "s6");
+
+            } else {
+                __asm__ volatile(
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vxor.vv            v24, v24, v24               \n\t"
+                    "addi               t3, %[BlockCountK], 0       \n\t"
+                    "addi               a1, %[A], 0                 \n\t"
+                    "addi               s1, %[B], 0                 \n\t"
+                    "BLOCK_COUNTK_LOOP%=:                           \n\t"
+                    "addi               s5, s1, 0                   \n\t"
+                    "addi               s1, s5, 64                  \n\t"
+                    "addi               s2, s1, 32                  \n\t"
+                    "addi               s3, s1, 32*2                \n\t"
+                    "addi               s4, s1, 32*3                \n\t"
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vxor.vv            v16, v16, v16               \n\t"
+                    // load a scale
+                    "flw                f1, (a1)                    \n\t"
+                    "flw                f2, 4(a1)                   \n\t"
+                    "flw                f3, 8(a1)                   \n\t"
+                    "flw                f4, 12(a1)                  \n\t"
+                    "addi               a1, a1, 16                  \n\t"
+                    "addi               t2, %[INNER], 0             \n\t"
+                    "BLOCK_INNER_LOOP%=:                            \n\t"
+
+                    LOAD_B_16x8x2
+
+                    "vsetvli            t0, zero, e8, m1            \n\t"
+                    "vle8.v             v10, (a1)                   \n\t"
+                    "addi               a1, a1, 32                  \n\t"
+                    "vle8.v             v11, (a1)                   \n\t"
+                    "addi               a1, a1, 32                  \n\t"
+                    "vadd.vi            v2, v2, -8                  \n\t"
+                    "vadd.vi            v3, v3, -8                  \n\t"
+                    "vadd.vi            v4, v4, -8                  \n\t"
+                    "vadd.vi            v5, v5, -8                  \n\t"
+                    "vadd.vi            v6, v6, -8                  \n\t"
+                    "vadd.vi            v7, v7, -8                  \n\t"
+                    "vadd.vi            v8, v8, -8                  \n\t"
+                    "vadd.vi            v9, v9, -8                  \n\t"
+
+                    SQ4BIT_KERNEL_COMP_4x16x16
+
+                    "addi               t2, t2, -1                  \n\t"
+                    "bnez               t2, BLOCK_INNER_LOOP%=      \n\t"
+
+                    LOAD_SCALE_4x16
+
+                    "vsetvli            t0, zero, e32, m8           \n\t"
+                    "vfcvt.f.x.v        v16, v16                    \n\t"
+                    "vfmacc.vv          v24, v16, v8                \n\t"
+                    "addi               t3, t3, -1                  \n\t"
+                    "bnez               t3, BLOCK_COUNTK_LOOP%=     \n\t"
+
+                    "RESULT_SAVE%=:                                 \n\t"
+
+                    SAVE_RESULT_4x16
+
+                    :
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+                      [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+                    : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+                      "s4", "s5", "s6");
+            }
+        }
+    }
+    if (CountN % 16 != 0) {
+        // store the output tail from tmp to C when the remaining column count (NBLKS) is less than 16.
+        float *      CPtr = C + CountN / 16 * 16;
+        const size_t N    = CountN % 16;
+        LDC               = ldc * sizeof(float);
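+        // vl is set from N below, so only the valid tail columns of each of the
+        // four result rows are copied from tmp into C.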
+        __asm__ volatile(
+            "vsetvli            t0, %[N], e32, m2       \n\t"
+            "vle32.v            v0, (%[SRC])            \n\t"
+            "addi               s2, %[SRC], 64          \n\t"
+            "addi               s3, %[SRC], 64*2        \n\t"
+            "addi               s4, %[SRC], 64*3        \n\t"
+            "vle32.v            v2, (s2)                \n\t"
+            "vle32.v            v4, (s3)                \n\t"
+            "vle32.v            v6, (s4)                \n\t"
+            "add                t2, %[DST], %[LDC]      \n\t"
+            "add                t3, t2, %[LDC]          \n\t"
+            "add                t4, t3, %[LDC]          \n\t"
+            "vse32.v            v0, (%[DST])            \n\t"
+            "vse32.v            v2, (t2)                \n\t"
+            "vse32.v            v4, (t3)                \n\t"
+            "vse32.v            v6, (t4)                \n\t"
+            :
+            : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
+            : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
+    }
+}
+
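+// M == 1 kernel for the int8 x int4 GEMM where the per-block B scales are stored as
+// _Float16; processes up to 16 output columns per pass, with separate code paths for
+// zero points (HasZeroPoint) and for an optional bias row.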
+template <bool HasZeroPoint>
+void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t            BlkLen,
+                                                const std::byte * QuantA,
+                                                const std::byte * QuantBData,
+                                                const float *     QuantBScale,
+                                                const std::byte * QuantBZeroPoint,
+                                                float *           C,
+                                                size_t            CountN,
+                                                size_t            BlockCountK,
+                                                const float *     Bias) {
+    GGML_UNUSED(QuantBScale);
+    GGML_UNUSED(QuantBZeroPoint);
+    size_t INNER = BlkLen / 16;
+
+    if constexpr (HasZeroPoint) {
+        for (size_t n = 0; n < CountN; n += 16) {
+            size_t      nblks         = (CountN - n) > 16 ? 16 : CountN - n;
+            std::byte * QuantBDataPtr = (std::byte *) QuantBData +           //
+                                        n * BlockCountK * BlkLen / 2 +       // b data
+                                        n * BlockCountK * sizeof(uint8_t) +  // zp
+                                        n * BlockCountK * sizeof(_Float16);    // scale
+            float * CPtr = C + n;
+            size_t  cnt  = BlockCountK;
+            if (Bias != nullptr) {
+                const float * bias = Bias + n;
+                __asm__ volatile(
+                    "addi         t3, %[NBLKS], 0         \n\t"
+                    "vsetvli      t0, zero, e8, m1        \n\t"
+
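+                    // Build a stepped constant in v13: all lanes are set to 3, then
+                    // progressively shorter prefixes (vl = 24, then mf2 and mf4) are
+                    // overwritten with 2, 1 and 0. Only the zero-point paths need v13.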
+                    "vmv.v.i      v13, 3                  \n\t"
+                    "li           s1, 24                  \n\t"
+                    "vsetvli      t0, s1, e8, m1          \n\t"
+                    "vmv.v.i      v13, 2                  \n\t"
+                    "vsetvli      t0, zero, e8, mf2       \n\t"
+                    "vmv.v.i      v13, 1                  \n\t"
+                    "vsetvli      t0, zero, e8, mf4       \n\t"
+                    "vmv.v.i      v13, 0                  \n\t"
+                    "addi         s1, %[B], 0             \n\t"
+                    "addi         s2, %[B], 8             \n\t"
+                    "addi         s3, %[B], 16            \n\t"
+                    "addi         s4, %[B], 24            \n\t"
+                    // zp offset
+                    "addi         s7, %[B], 32            \n\t"
+                    // A data offset
+                    "addi         s5, %[A], 0             \n\t"
+                    "addi         s6, %[A], 12            \n\t"
+
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v28, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v29, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v30, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v31, (%[BIAS])          \n\t"
+
+                    "LOOP_K%=:                            \n\t"
+                    "vsetvli      t0, zero, e16, mf4      \n\t"
+
+                    "vle16.v      v4, (s1)                \n\t"
+                    "addi         s1, s1, 48              \n\t"
+                    "vle16.v      v5, (s2)                \n\t"
+                    "addi         s2, s2, 72              \n\t"
+                    "vle16.v      v6, (s3)                \n\t"
+                    "addi         s3, s3, 96              \n\t"
+                    "vle16.v      v7, (s4)                \n\t"
+                    "addi         s4, s4, 120             \n\t"
+                    "flw          f1, (s5)                \n\t"
+                    "addi         s5, s5, 4               \n\t"
+                    "vfwcvt.f.f.v v8, v4                  \n\t"
+                    "vfwcvt.f.f.v v9, v5                  \n\t"
+                    "vfwcvt.f.f.v v10, v6                 \n\t"
+                    "vfwcvt.f.f.v v11, v7                 \n\t"
+
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+                    "addi         t5, %[INNER], 0         \n\t"
+                    "vxor.vv      v16, v16, v16           \n\t"
+                    "vxor.vv      v18, v18, v18           \n\t"
+                    "vxor.vv      v20, v20, v20           \n\t"
+                    "vxor.vv      v22, v22, v22           \n\t"
+                    "vfmul.vf     v24, v8, f1             \n\t"
+                    "vfmul.vf     v25, v9, f1             \n\t"
+                    "vfmul.vf     v26, v10, f1            \n\t"
+                    "vfmul.vf     v27, v11, f1            \n\t"
+                    "addi         %[CNT], %[CNT], -1      \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+                    "LOOP_INNER%=:                        \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+                    "vsub.vv      v0, v0, v8              \n\t"
+                    "vsub.vv      v4, v4, v8              \n\t"
+                    "vsub.vv      v1, v1, v9              \n\t"
+                    "vsub.vv      v5, v5, v9              \n\t"
+                    "vsub.vv      v2, v2, v10             \n\t"
+                    "vsub.vv      v6, v6, v10             \n\t"
+                    "vsub.vv      v3, v3, v11             \n\t"
+                    "vsub.vv      v7, v7, v11             \n\t"
+
+                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+                    "bnez         t5, LOOP_INNER%=        \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+
+                    SQ4BIT_KERNEL_ACC_F16_1X4X4
+                    "addi         s7, s1, 32              \n\t"
+
+                    "bnez         %[CNT], LOOP_K%=        \n\t"
+                    "addi         t3, zero, 16            \n\t"
+                    "addi         s1, %[C], 16            \n\t"
+                    "addi         s2, %[C], 32            \n\t"
+                    "addi         s3, %[C], 48            \n\t"
+                    "blt          %[NBLKS], t3, ST_TAIL%= \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "jal          x0, END%=               \n\t"
+
+                    "ST_TAIL%=:                           \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "END%=:                               \n\t"
+
+                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+            } else {
+                __asm__ volatile(
+                    "vsetvli      t0, zero, e32, m4       \n\t"
+                    "vxor.vv      v28, v28, v28           \n\t"
+
+                    "vsetvli      t0, zero, e8, m1        \n\t"
+                    "vmv.v.i      v13, 3                  \n\t"
+                    "li           s1, 24                  \n\t"
+                    "vsetvli      t0, s1, e8, m1          \n\t"
+                    "vmv.v.i      v13, 2                  \n\t"
+                    "vsetvli      t0, zero, e8, mf2       \n\t"
+                    "vmv.v.i      v13, 1                  \n\t"
+                    "vsetvli      t0, zero, e8, mf4       \n\t"
+                    "vmv.v.i      v13, 0                  \n\t"
+
+                    "addi         s1, %[B], 0             \n\t"
+                    "addi         s2, %[B], 8             \n\t"
+                    "addi         s3, %[B], 16            \n\t"
+                    "addi         s4, %[B], 24            \n\t"
+
+                    "addi         s7, %[B], 32            \n\t"
+
+                    "addi         s5, %[A], 0             \n\t"
+                    "addi         s6, %[A], 12            \n\t"
+                    "LOOP_K%=:                            \n\t"
+                    "vsetvli      t0, zero, e16, mf4      \n\t"
+                    "vle16.v      v4, (s1)                \n\t"
+                    "addi         s1, s1, 48              \n\t"
+                    "vle16.v      v5, (s2)                \n\t"
+                    "addi         s2, s2, 72              \n\t"
+                    "vle16.v      v6, (s3)                \n\t"
+                    "addi         s3, s3, 96              \n\t"
+                    "vle16.v      v7, (s4)                \n\t"
+                    "addi         s4, s4, 120             \n\t"
+                    "flw          f1, (s5)                \n\t"
+                    "addi         s5, s5, 4               \n\t"
+
+                    "vfwcvt.f.f.v v8, v4                  \n\t"
+                    "vfwcvt.f.f.v v9, v5                  \n\t"
+                    "vfwcvt.f.f.v v10, v6                 \n\t"
+                    "vfwcvt.f.f.v v11, v7                 \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+
+                    "addi         t5, %[INNER], 0         \n\t"
+                    "vxor.vv      v16, v16, v16           \n\t"
+                    "vxor.vv      v18, v18, v18           \n\t"
+                    "vxor.vv      v20, v20, v20           \n\t"
+                    "vxor.vv      v22, v22, v22           \n\t"
+                    "vfmul.vf     v24, v8, f1             \n\t"
+                    "vfmul.vf     v25, v9, f1             \n\t"
+                    "vfmul.vf     v26, v10, f1            \n\t"
+                    "vfmul.vf     v27, v11, f1            \n\t"
+                    "addi         %[CNT], %[CNT], -1      \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+                    "LOOP_INNER%=:                        \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+                    "vsub.vv      v0, v0, v8              \n\t"
+                    "vsub.vv      v4, v4, v8              \n\t"
+                    "vsub.vv      v1, v1, v9              \n\t"
+                    "vsub.vv      v5, v5, v9              \n\t"
+                    "vsub.vv      v2, v2, v10             \n\t"
+                    "vsub.vv      v6, v6, v10             \n\t"
+                    "vsub.vv      v3, v3, v11             \n\t"
+                    "vsub.vv      v7, v7, v11             \n\t"
+
+                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+                    "bnez         t5, LOOP_INNER%=        \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+
+                    SQ4BIT_KERNEL_ACC_F16_1X4X4
+                    "addi         s7, s1, 32              \n\t"
+
+                    "bnez         %[CNT], LOOP_K%=        \n\t"
+                    "addi         t3, zero, 16            \n\t"
+                    "addi         s1, %[C], 16            \n\t"
+                    "addi         s2, %[C], 32            \n\t"
+                    "addi         s3, %[C], 48            \n\t"
+                    "blt          %[NBLKS], t3, ST_TAIL%= \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "jal          x0, END%=               \n\t"
+
+                    "ST_TAIL%=:                           \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "END%=:                               \n\t"
+
+                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+            }
+        }
+    } else {
+        for (size_t n = 0; n < CountN; n += 16) {
+            size_t      nblks         = (CountN - n) > 16 ? 16 : CountN - n;
+            std::byte * QuantBDataPtr = (std::byte *) QuantBData +         //
+                                        n * BlockCountK * BlkLen / 2 +     // b data
+                                        n * BlockCountK * sizeof(_Float16);  // scale
+            float * CPtr = C + n;
+            size_t  cnt  = BlockCountK;
+            if (Bias != nullptr) {
+                const float * bias = Bias + n;
+                __asm__ volatile(
+                    "addi         t3, %[NBLKS], 0         \n\t"
+                    "addi         s1, %[B], 0             \n\t"
+                    "addi         s2, %[B], 8             \n\t"
+                    "addi         s3, %[B], 16            \n\t"
+                    "addi         s4, %[B], 24            \n\t"
+                    "addi         s5, %[A], 0             \n\t"
+                    "addi         s6, %[A], 12            \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v28, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v29, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v30, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v31, (%[BIAS])          \n\t"
+
+                    "LOOP_K%=:                            \n\t"
+                    "vsetvli      t0, zero, e16, mf4      \n\t"
+
+                    "vle16.v      v4, (s1)                \n\t"
+                    "addi         s1, s1, 32              \n\t"
+                    "vle16.v      v5, (s2)                \n\t"
+                    "addi         s2, s2, 56              \n\t"
+                    "vle16.v      v6, (s3)                \n\t"
+                    "addi         s3, s3, 80              \n\t"
+                    "vle16.v      v7, (s4)                \n\t"
+                    "addi         s4, s4, 104             \n\t"
+                    "flw          f1, (s5)                \n\t"
+                    "addi         s5, s5, 4               \n\t"
+                    "vfwcvt.f.f.v v8, v4                  \n\t"
+                    "vfwcvt.f.f.v v9, v5                  \n\t"
+                    "vfwcvt.f.f.v v10, v6                 \n\t"
+                    "vfwcvt.f.f.v v11, v7                 \n\t"
+
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+                    "addi         t5, %[INNER], 0         \n\t"
+                    "vxor.vv      v16, v16, v16           \n\t"
+                    "vxor.vv      v18, v18, v18           \n\t"
+                    "vxor.vv      v20, v20, v20           \n\t"
+                    "vxor.vv      v22, v22, v22           \n\t"
+                    "vfmul.vf     v24, v8, f1             \n\t"
+                    "vfmul.vf     v25, v9, f1             \n\t"
+                    "vfmul.vf     v26, v10, f1            \n\t"
+                    "vfmul.vf     v27, v11, f1            \n\t"
+                    "addi         %[CNT], %[CNT], -1      \n\t"
+                    "vsetvli      t0, zero, e8, m1        \n\t"
+                    "LOOP_INNER%=:                        \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+                    "vadd.vi      v0, v0, -8              \n\t"
+                    "vadd.vi      v1, v1, -8              \n\t"
+                    "vadd.vi      v2, v2, -8              \n\t"
+                    "vadd.vi      v3, v3, -8              \n\t"
+                    "vadd.vi      v4, v4, -8              \n\t"
+                    "vadd.vi      v5, v5, -8              \n\t"
+                    "vadd.vi      v6, v6, -8              \n\t"
+                    "vadd.vi      v7, v7, -8              \n\t"
+
+                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+                    "bnez         t5, LOOP_INNER%=        \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+
+                    SQ4BIT_KERNEL_ACC_F16_1X4X4
+
+                    "bnez         %[CNT], LOOP_K%=        \n\t"
+                    "addi         t3, zero, 16            \n\t"
+                    "addi         s1, %[C], 16            \n\t"
+                    "addi         s2, %[C], 32            \n\t"
+                    "addi         s3, %[C], 48            \n\t"
+                    "blt          %[NBLKS], t3, ST_TAIL%= \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "jal          x0, END%=               \n\t"
+
+                    "ST_TAIL%=:                           \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "END%=:                               \n\t"
+
+                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+            } else {
+                __asm__ volatile(
+                    "vsetvli      t0, zero, e32, m4       \n\t"
+                    "vxor.vv      v28, v28, v28           \n\t"
+                    "addi         s1, %[B], 0             \n\t"
+                    "addi         s2, %[B], 8             \n\t"
+                    "addi         s3, %[B], 16            \n\t"
+                    "addi         s4, %[B], 24            \n\t"
+
+                    "addi         s5, %[A], 0             \n\t"
+                    "addi         s6, %[A], 12            \n\t"
+                    "LOOP_K%=:                            \n\t"
+                    "vsetvli      t0, zero, e16, mf4      \n\t"
+                    "vle16.v      v4, (s1)                \n\t"
+                    "addi         s1, s1, 32              \n\t"
+                    "vle16.v      v5, (s2)                \n\t"
+                    "addi         s2, s2, 56              \n\t"
+                    "vle16.v      v6, (s3)                \n\t"
+                    "addi         s3, s3, 80              \n\t"
+                    "vle16.v      v7, (s4)                \n\t"
+                    "addi         s4, s4, 104             \n\t"
+                    "flw          f1, (s5)                \n\t"
+                    "addi         s5, s5, 4               \n\t"
+
+                    "vfwcvt.f.f.v v8, v4                  \n\t"
+                    "vfwcvt.f.f.v v9, v5                  \n\t"
+                    "vfwcvt.f.f.v v10, v6                 \n\t"
+                    "vfwcvt.f.f.v v11, v7                 \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+
+                    "addi         t5, %[INNER], 0         \n\t"
+                    "vxor.vv      v16, v16, v16           \n\t"
+                    "vxor.vv      v18, v18, v18           \n\t"
+                    "vxor.vv      v20, v20, v20           \n\t"
+                    "vxor.vv      v22, v22, v22           \n\t"
+                    "vfmul.vf     v24, v8, f1             \n\t"
+                    "vfmul.vf     v25, v9, f1             \n\t"
+                    "vfmul.vf     v26, v10, f1            \n\t"
+                    "vfmul.vf     v27, v11, f1            \n\t"
+                    "addi         %[CNT], %[CNT], -1      \n\t"
+                    "vsetvli      t0, zero, e8, m1        \n\t"
+                    "LOOP_INNER%=:                        \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+                    "vadd.vi      v0, v0, -8              \n\t"
+                    "vadd.vi      v1, v1, -8              \n\t"
+                    "vadd.vi      v2, v2, -8              \n\t"
+                    "vadd.vi      v3, v3, -8              \n\t"
+                    "vadd.vi      v4, v4, -8              \n\t"
+                    "vadd.vi      v5, v5, -8              \n\t"
+                    "vadd.vi      v6, v6, -8              \n\t"
+                    "vadd.vi      v7, v7, -8              \n\t"
+
+                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+                    "bnez         t5, LOOP_INNER%=        \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+
+                    SQ4BIT_KERNEL_ACC_F16_1X4X4
+
+                    "bnez         %[CNT], LOOP_K%=        \n\t"
+                    "addi         t3, zero, 16            \n\t"
+                    "addi         s1, %[C], 16            \n\t"
+                    "addi         s2, %[C], 32            \n\t"
+                    "addi         s3, %[C], 48            \n\t"
+                    "blt          %[NBLKS], t3, ST_TAIL%= \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "jal          x0, END%=               \n\t"
+
+                    "ST_TAIL%=:                           \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "END%=:                               \n\t"
+
+                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+            }
+        }
+    }
+}
+
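+// M == 1 kernel variant with per-block B scales stored as float; structured like the
+// _Float16 variant above, but loads the scales directly as 32-bit floats.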
+template <bool HasZeroPoint>
+void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t            BlkLen,
+                                      const std::byte * QuantA,
+                                      const std::byte * QuantBData,
+                                      const float *     QuantBScale,
+                                      const std::byte * QuantBZeroPoint,
+                                      float *           C,
+                                      size_t            CountN,
+                                      size_t            BlockCountK,
+                                      const float *     Bias) {
+    GGML_UNUSED(QuantBScale);
+    GGML_UNUSED(QuantBZeroPoint);
+    const size_t INNER = BlkLen / 16;
+    if constexpr (HasZeroPoint) {
+        for (size_t n = 0; n < CountN; n += 16) {
+            size_t      nblks         = (CountN - n) > 16 ? 16 : CountN - n;
+            std::byte * QuantBDataPtr = (std::byte *) QuantBData +           //
+                                        n * BlockCountK * BlkLen / 2 +       // b data
+                                        n * BlockCountK * sizeof(uint8_t) +  // zp
+                                        n * BlockCountK * sizeof(float);     // scale
+            float * CPtr = C + n;
+            size_t  cnt  = BlockCountK;
+            if (Bias != nullptr) {
+                const float * bias = Bias + n;
+                __asm__ volatile(
+                    "addi         t3, %[NBLKS], 0         \n\t"
+                    "vsetvli      t0, zero, e8, m1        \n\t"
+                    "vmv.v.i      v13, 3                  \n\t"
+                    "li           s1, 24                  \n\t"
+                    "vsetvli      t0, s1, e8, m1          \n\t"
+                    "vmv.v.i      v13, 2                  \n\t"
+                    "vsetvli      t0, zero, e8, mf2       \n\t"
+                    "vmv.v.i      v13, 1                  \n\t"
+                    "vsetvli      t0, zero, e8, mf4       \n\t"
+                    "vmv.v.i      v13, 0                  \n\t"
+                    "vsetvli      t0, zero, e32, m4       \n\t"
+                    "vxor.vv      v28, v28, v28           \n\t"
+
+                    // B scale offsets: scale0.0, scale1.0, scale2.0, ..., scale15.0
+                    "addi         s1, %[B], 0             \n\t"
+                    "addi         s2, %[B], 16            \n\t"
+                    "addi         s3, %[B], 32            \n\t"
+                    "addi         s4, %[B], 48            \n\t"
+                    // zp offset
+                    "addi         s7, %[B], 64            \n\t"
+                    // A data offset
+                    "addi         s5, %[A], 0             \n\t"
+                    "addi         s6, %[A], 12            \n\t"
+
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v28, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v29, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v30, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v31, (%[BIAS])          \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+                    "LOOP_K%=:                            \n\t"
+
+                    // load the B scales
+                    "vle32.v      v8, (s1)                \n\t"
+                    "addi         s1, s1, 80              \n\t"
+                    "vle32.v      v9, (s2)                \n\t"
+                    "addi         s2, s2, 96              \n\t"
+                    "vle32.v      v10, (s3)               \n\t"
+                    "addi         s3, s3, 112             \n\t"
+                    "vle32.v      v11, (s4)               \n\t"
+                    "addi         s4, s4, 128             \n\t"
+
+                    // load the A scale
+                    "flw          f1, (s5)                \n\t"
+                    "addi         s5, s5, 4               \n\t"
+
+                    "addi         t5, %[INNER], 0         \n\t"
+                    "vxor.vv      v16, v16, v16           \n\t"
+                    "vxor.vv      v18, v18, v18           \n\t"
+                    "vxor.vv      v20, v20, v20           \n\t"
+                    "vxor.vv      v22, v22, v22           \n\t"
+
+                    // A scale * B scale
+                    "vfmul.vf     v24, v8, f1             \n\t"
+                    "vfmul.vf     v25, v9, f1             \n\t"
+                    "vfmul.vf     v26, v10, f1            \n\t"
+                    "vfmul.vf     v27, v11, f1            \n\t"
+                    "addi         %[CNT], %[CNT], -1      \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+                    "LOOP_INNER%=:                        \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+                    "vsub.vv      v0, v0, v8              \n\t"
+                    "vsub.vv      v4, v4, v8              \n\t"
+                    "vsub.vv      v1, v1, v9              \n\t"
+                    "vsub.vv      v5, v5, v9              \n\t"
+                    "vsub.vv      v2, v2, v10             \n\t"
+                    "vsub.vv      v6, v6, v10             \n\t"
+                    "vsub.vv      v3, v3, v11             \n\t"
+                    "vsub.vv      v7, v7, v11             \n\t"
+
+                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+                    "bnez         t5, LOOP_INNER%=        \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+
+                    SQ4BIT_KERNEL_ACC_1X4X4
+                    "addi         s7, s1, 64              \n\t"
+
+                    "bnez         %[CNT], LOOP_K%=        \n\t"
+
+                    "addi         t3, zero, 16            \n\t"
+                    "addi         s1, %[C], 16            \n\t"
+                    "addi         s2, %[C], 32            \n\t"
+                    "addi         s3, %[C], 48            \n\t"
+                    "blt          %[NBLKS], t3, ST_TAIL%= \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "jal          x0, END%=               \n\t"
+
+                    "ST_TAIL%=:                           \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "END%=:                               \n\t"
+
+                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+            } else {
+                __asm__ volatile(
+                    "vsetvli      t0, zero, e32, m4       \n\t"
+                    "vxor.vv      v28, v28, v28           \n\t"
+
+                    "vsetvli      t0, zero, e8, m1        \n\t"
+                    "vmv.v.i      v13, 3                  \n\t"
+                    "li           s1, 24                  \n\t"
+                    "vsetvli      t0, s1, e8, m1          \n\t"
+                    "vmv.v.i      v13, 2                  \n\t"
+                    "vsetvli      t0, zero, e8, mf2       \n\t"
+                    "vmv.v.i      v13, 1                  \n\t"
+                    "vsetvli      t0, zero, e8, mf4       \n\t"
+                    "vmv.v.i      v13, 0                  \n\t"
+                    "addi         s1, %[B], 0             \n\t"
+                    "addi         s2, %[B], 16            \n\t"
+                    "addi         s3, %[B], 32            \n\t"
+                    "addi         s4, %[B], 48            \n\t"
+
+                    "addi         s7, %[B], 64            \n\t"
+
+                    "addi         s5, %[A], 0             \n\t"
+                    "addi         s6, %[A], 12            \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+
+                    "LOOP_K%=:                            \n\t"
+                    "vle32.v      v8, (s1)                \n\t"
+                    "addi         s1, s1, 80              \n\t"
+                    "vle32.v      v9, (s2)                \n\t"
+                    "addi         s2, s2, 96              \n\t"
+                    "vle32.v      v10, (s3)               \n\t"
+                    "addi         s3, s3, 112             \n\t"
+                    "vle32.v      v11, (s4)               \n\t"
+                    "addi         s4, s4, 128             \n\t"
+
+                    "flw          f1, (s5)                \n\t"
+                    "addi         s5, s5, 4               \n\t"
+
+                    "addi         t5, %[INNER], 0         \n\t"
+                    "vxor.vv      v16, v16, v16           \n\t"
+                    "vxor.vv      v18, v18, v18           \n\t"
+                    "vxor.vv      v20, v20, v20           \n\t"
+                    "vxor.vv      v22, v22, v22           \n\t"
+
+                    "vfmul.vf     v24, v8, f1             \n\t"
+                    "vfmul.vf     v25, v9, f1             \n\t"
+                    "vfmul.vf     v26, v10, f1            \n\t"
+                    "vfmul.vf     v27, v11, f1            \n\t"
+                    "addi         %[CNT], %[CNT], -1      \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+                    "LOOP_INNER%=:                        \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+                    "vsub.vv      v0, v0, v8              \n\t"
+                    "vsub.vv      v4, v4, v8              \n\t"
+                    "vsub.vv      v1, v1, v9              \n\t"
+                    "vsub.vv      v5, v5, v9              \n\t"
+                    "vsub.vv      v2, v2, v10             \n\t"
+                    "vsub.vv      v6, v6, v10             \n\t"
+                    "vsub.vv      v3, v3, v11             \n\t"
+                    "vsub.vv      v7, v7, v11             \n\t"
+
+                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+                    "bnez         t5, LOOP_INNER%=        \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+
+                    SQ4BIT_KERNEL_ACC_1X4X4
+                    "addi         s7, s1, 64              \n\t"
+
+                    "bnez         %[CNT], LOOP_K%=        \n\t"
+
+                    "addi         t3, zero, 16            \n\t"
+                    "addi         s1, %[C], 16            \n\t"
+                    "addi         s2, %[C], 32            \n\t"
+                    "addi         s3, %[C], 48            \n\t"
+                    "blt          %[NBLKS], t3, ST_TAIL%= \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "jal          x0, END%=               \n\t"
+
+                    "ST_TAIL%=:                           \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "END%=:                               \n\t"
+
+                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+            }
+        }
+    } else {
+        for (size_t n = 0; n < CountN; n += 16) {
+            size_t      nblks         = (CountN - n) > 16 ? 16 : CountN - n;
+            std::byte * QuantBDataPtr = (std::byte *) QuantBData +        //
+                                        n * BlockCountK * BlkLen / 2 +    // b data
+                                        n * BlockCountK * sizeof(float);  // scale
+            float * CPtr = C + n;
+            size_t  cnt  = BlockCountK;
+            if (Bias != nullptr) {
+                const float * bias = Bias + n;
+                __asm__ volatile(
+                    "addi         t3, %[NBLKS], 0         \n\t"
+                    "addi         s1, %[B], 0             \n\t"
+                    "addi         s2, %[B], 16            \n\t"
+                    "addi         s3, %[B], 32            \n\t"
+                    "addi         s4, %[B], 48            \n\t"
+                    "addi         s5, %[A], 0             \n\t"
+                    "addi         s6, %[A], 12            \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v28, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v29, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v30, (%[BIAS])          \n\t"
+                    "sub          t3, t3, t0              \n\t"
+                    "addi         %[BIAS], %[BIAS], 16    \n\t"
+                    "vsetvli      t0, t3, e32, mf2        \n\t"
+                    "vle32.v      v31, (%[BIAS])          \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+                    "LOOP_K%=:                            \n\t"
+                    "vle32.v      v8, (s1)                \n\t"
+                    "addi         s1, s1, 64              \n\t"
+                    "vle32.v      v9, (s2)                \n\t"
+                    "addi         s2, s2, 80              \n\t"
+                    "vle32.v      v10, (s3)               \n\t"
+                    "addi         s3, s3, 96              \n\t"
+                    "vle32.v      v11, (s4)               \n\t"
+                    "addi         s4, s4, 112             \n\t"
+                    "flw          f1, (s5)                \n\t"
+                    "addi         s5, s5, 4               \n\t"
+
+                    "addi         t5, %[INNER], 0         \n\t"
+                    "vxor.vv      v16, v16, v16           \n\t"
+                    "vxor.vv      v18, v18, v18           \n\t"
+                    "vxor.vv      v20, v20, v20           \n\t"
+                    "vxor.vv      v22, v22, v22           \n\t"
+                    "vfmul.vf     v24, v8, f1             \n\t"
+                    "vfmul.vf     v25, v9, f1             \n\t"
+                    "vfmul.vf     v26, v10, f1            \n\t"
+                    "vfmul.vf     v27, v11, f1            \n\t"
+                    "addi         %[CNT], %[CNT], -1      \n\t"
+                    "vsetvli      t0, zero, e8, m1        \n\t"
+                    "LOOP_INNER%=:                        \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+                    "vadd.vi      v0, v0, -8              \n\t"
+                    "vadd.vi      v1, v1, -8              \n\t"
+                    "vadd.vi      v2, v2, -8              \n\t"
+                    "vadd.vi      v3, v3, -8              \n\t"
+                    "vadd.vi      v4, v4, -8              \n\t"
+                    "vadd.vi      v5, v5, -8              \n\t"
+                    "vadd.vi      v6, v6, -8              \n\t"
+                    "vadd.vi      v7, v7, -8              \n\t"
+
+                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+                    "bnez         t5, LOOP_INNER%=        \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+
+                    SQ4BIT_KERNEL_ACC_1X4X4
+
+                    "bnez         %[CNT], LOOP_K%=        \n\t"
+                    "addi         t3, zero, 16            \n\t"
+                    "addi         s1, %[C], 16            \n\t"
+                    "addi         s2, %[C], 32            \n\t"
+                    "addi         s3, %[C], 48            \n\t"
+                    "blt          %[NBLKS], t3, ST_TAIL%= \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "jal          x0, END%=               \n\t"
+
+                    "ST_TAIL%=:                           \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "END%=:                               \n\t"
+
+                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+            } else {
+                __asm__ volatile(
+                    "vsetvli      t0, zero, e32, m4       \n\t"
+                    "vxor.vv      v28, v28, v28           \n\t"
+                    "addi         s1, %[B], 0             \n\t"
+                    "addi         s2, %[B], 16            \n\t"
+                    "addi         s3, %[B], 32            \n\t"
+                    "addi         s4, %[B], 48            \n\t"
+
+                    "addi         s5, %[A], 0             \n\t"
+                    "addi         s6, %[A], 12            \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+                    "LOOP_K%=:                            \n\t"
+                    "vle32.v      v8, (s1)                \n\t"
+                    "addi         s1, s1, 64              \n\t"
+                    "vle32.v      v9, (s2)                \n\t"
+                    "addi         s2, s2, 80              \n\t"
+                    "vle32.v      v10, (s3)               \n\t"
+                    "addi         s3, s3, 96              \n\t"
+                    "vle32.v      v11, (s4)               \n\t"
+                    "addi         s4, s4, 112             \n\t"
+                    "flw          f1, (s5)                \n\t"
+                    "addi         s5, s5, 4               \n\t"
+
+                    "addi         t5, %[INNER], 0         \n\t"
+                    "vxor.vv      v16, v16, v16           \n\t"
+                    "vxor.vv      v18, v18, v18           \n\t"
+                    "vxor.vv      v20, v20, v20           \n\t"
+                    "vxor.vv      v22, v22, v22           \n\t"
+                    "vfmul.vf     v24, v8, f1             \n\t"
+                    "vfmul.vf     v25, v9, f1             \n\t"
+                    "vfmul.vf     v26, v10, f1            \n\t"
+                    "vfmul.vf     v27, v11, f1            \n\t"
+                    "addi         %[CNT], %[CNT], -1      \n\t"
+                    "vsetvli      t0, zero, e8, m1        \n\t"
+                    "LOOP_INNER%=:                        \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+                    "vadd.vi      v0, v0, -8              \n\t"
+                    "vadd.vi      v1, v1, -8              \n\t"
+                    "vadd.vi      v2, v2, -8              \n\t"
+                    "vadd.vi      v3, v3, -8              \n\t"
+                    "vadd.vi      v4, v4, -8              \n\t"
+                    "vadd.vi      v5, v5, -8              \n\t"
+                    "vadd.vi      v6, v6, -8              \n\t"
+                    "vadd.vi      v7, v7, -8              \n\t"
+
+                    SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+                    "bnez         t5, LOOP_INNER%=        \n\t"
+                    "vsetvli      t0, zero, e32, mf2      \n\t"
+
+                    SQ4BIT_KERNEL_ACC_1X4X4
+
+                    "bnez         %[CNT], LOOP_K%=        \n\t"
+                    "addi         t3, zero, 16            \n\t"
+                    "addi         s1, %[C], 16            \n\t"
+                    "addi         s2, %[C], 32            \n\t"
+                    "addi         s3, %[C], 48            \n\t"
+                    "blt          %[NBLKS], t3, ST_TAIL%= \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "jal          x0, END%=               \n\t"
+
+                    "ST_TAIL%=:                           \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v28, (%[C])             \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v29, (s1)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v30, (s2)               \n\t"
+                    "vsetvli      t0, %[NBLKS], e32, mf2  \n\t"
+                    "sub          %[NBLKS], %[NBLKS], t0  \n\t"
+                    "vse32.v      v31, (s3)               \n\t"
+                    "END%=:                               \n\t"
+
+                    : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+                    : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+            }
+        }
+    }
+}
+
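+// Pick the M == 4 kernel variant from the stride of the packed per-block B scale:
+// 4 bytes selects the float-scale kernel, 2 bytes the _Float16-scale kernel.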
+template <bool HasZeroPoint>
+inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t            BlkLen,
+                                                         const std::byte * QuantA,
+                                                         const std::byte * QuantBData,
+                                                         const float *     QuantBScale,
+                                                         const std::byte * QuantBZeroPoint,
+                                                         float *           C,
+                                                         size_t            CountM,
+                                                         size_t            CountN,
+                                                         size_t            BlockStrideQuantB,
+                                                         const float *     Bias,
+                                                         const size_t      ldc,
+                                                         const size_t      scalestride) {
+    if (scalestride == 4) {
+        SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
+                                                       CountN, BlockStrideQuantB, Bias, ldc);
+
+    } else if (scalestride == 2) {
+        SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(
+            BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc);
+    }
+}
+
+template <bool HasZeroPoint>
+inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t            BlkLen,
+                                                         const std::byte * QuantA,
+                                                         const std::byte * QuantBData,
+                                                         const float *     QuantBScale,
+                                                         const std::byte * QuantBZeroPoint,
+                                                         float *           C,
+                                                         size_t            CountM,
+                                                         size_t            CountN,
+                                                         size_t            BlockStrideQuantB,
+                                                         const float *     Bias,
+                                                         const size_t      ldc,
+                                                         const size_t      scalestride) {
+    if (scalestride == 4) {
+        SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
+                                                       CountN, BlockStrideQuantB, Bias);
+    } else if (scalestride == 2) {
+        SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale,
+                                                                 QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias);
+    }
+}
+
+}  // namespace
+
+namespace ime1 {
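+// Entry point for the i8 x i4 GEMM tile: dispatches on CountM (>= 4 rows vs. a single
+// row), on whether zero points are present, and on the scale stride, and returns the
+// number of rows of A consumed by this call (4 or 1).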
+size_t gemm_kernel_i8i4(size_t            BlkLen,
+                        const std::byte * QuantA,
+                        const std::byte * QuantBData,
+                        const float *     QuantBScale,
+                        const std::byte * QuantBZeroPoint,
+                        float *           C,
+                        size_t            CountM,
+                        size_t            CountN,
+                        size_t            CountK,
+                        size_t            BlockCountK,
+                        size_t            ldc,
+                        const float *     Bias,
+                        const size_t      ScaleStride) {
+    GGML_UNUSED(CountM);
+    GGML_UNUSED(CountK);
+    GGML_UNUSED(ldc);
+    if (CountM >= 4) {
+        if (QuantBZeroPoint != nullptr) {
+            SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
+                                                               C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
+        } else {
+            SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
+                                                                QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
+                                                                ldc, ScaleStride);
+        }
+        return 4;
+    } else {
+        if (QuantBZeroPoint != nullptr) {
+            SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
+                                                               C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
+        } else {
+            SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
+                                                                QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
+                                                                ldc, ScaleStride);
+        }
+        return 1;
+    }
+}
+}  // namespace ime1
+}  // namespace sqnbitgemm_spacemit_ime
diff --git a/src/ggml-cpu/spacemit/ime_kernels.h b/src/ggml-cpu/spacemit/ime_kernels.h
new file mode 100644 (file)
index 0000000..7570634
--- /dev/null
@@ -0,0 +1,26 @@
+#pragma once
+
+#include <cstddef>
+
+namespace sqnbitgemm_spacemit_ime {
+namespace ime1 {
+size_t gemm_kernel_i8i4(size_t            blk_len,
+                        const std::byte * quant_a_ptr,
+                        const std::byte * quant_b_data,
+                        const float *     quant_b_scale,
+                        const std::byte * quant_b_zp,
+                        float *           c_ptr,
+                        size_t            count_m,
+                        size_t            count_n,
+                        size_t            count_k,
+                        size_t            block_count_k,
+                        size_t            ldc,
+                        const float *     bias,
+                        const size_t      scale_stride);
+
+void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
+
+void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
+
+}  // namespace ime1
+}  // namespace sqnbitgemm_spacemit_ime