#define VECTOR_REGISTERS 16
#endif
+#if defined(__riscv_v_intrinsic)
+// LMUL: RVV register-group multiplier used by the tinyBLAS_RVV kernels below.
+// Larger LMUL gives wider effective vectors but leaves fewer usable vector
+// register groups, which is why matmul() picks smaller register tiles as
+// LMUL grows (see the LMUL == 1/2/4 branches in tinyBLAS_RVV::matmul).
+#define LMUL 4
+#endif
+
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
namespace {
}
#endif
+#if defined(__riscv_zvfh)
+// madd(a, b, c) = c + a * b, element-wise, over a full VLMAX worth of lanes.
+// The f16 variants use the widening vfwmacc form: the f32 accumulator vector
+// has twice the LMUL of the f16 inputs (e.g. f16m2 inputs -> f32m4 result).
+// NOTE(review): only this first overload is written as an explicit template
+// specialization (`template <>`); the rest are plain overloads. Presumably
+// both resolve identically at the call sites — confirm, and consider making
+// the style uniform.
+template <>
+inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
+ return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
+}
+inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
+ return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
+}
+inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
+ return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
+}
+inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
+ return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
+}
+// f32 variants: plain (non-widening) fused multiply-accumulate.
+inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
+ return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
+}
+inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
+ return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
+}
+inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
+ return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
+}
+inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
+ return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
+}
+#endif
+
+#if defined(__riscv_zvfbfwma)
+// bf16 widening multiply-accumulate into f32 (Zvfbfwma extension):
+// c(f32) += a(bf16) * b(bf16), again at the f32 result type's VLMAX.
+inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
+ return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
+}
+inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
+ return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
+}
+inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
+ return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
+}
+#endif
+
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM
}
#endif // __AVX512F__
+#if defined(__riscv_v_intrinsic)
+// hsum(x): horizontal sum of an f32 vector, via an unordered single-width
+// reduction (vfredusum) seeded with a 1-element zero vector, then moving the
+// scalar result out of element 0.
+// Gated on the base vector intrinsics rather than Zvfh: these f32 helpers
+// are also required by the Zvfbfwma (bf16) tinyBLAS path, and Zvfbfwma does
+// not imply Zvfh (vlmax<> below is already gated this way).
+inline float hsum(vfloat32m1_t x) {
+ return __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
+}
+inline float hsum(vfloat32m2_t x) {
+ return __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2()));
+}
+inline float hsum(vfloat32m4_t x) {
+ return __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4()));
+}
+inline float hsum(vfloat32m8_t x) {
+ return __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8()));
+}
+#endif
+
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED MEMORY LOADING
}
#endif
+#if defined(__riscv_zvfh)
+// load<V>(p): load one VLMAX-sized vector of type V starting at p.
+// ggml_fp16_t is bit-compatible with the compiler's _Float16, hence the
+// reinterpret_cast. The f32 loads are only used by the zvfh-gated F32
+// tinyBLAS instantiation; if an f32 path is ever enabled without Zvfh,
+// they would need to move under __riscv_v_intrinsic — TODO confirm.
+template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
+ return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
+}
+template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
+ return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
+}
+template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
+ return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
+}
+template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
+ return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
+}
+template <> inline vfloat32m1_t load(const float *p) {
+ return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
+}
+template <> inline vfloat32m2_t load(const float *p) {
+ return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
+}
+template <> inline vfloat32m4_t load(const float *p) {
+ return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
+}
+template <> inline vfloat32m8_t load(const float *p) {
+ return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
+}
+#endif
+
+#if defined(__riscv_zvfbfwma)
+// bf16 loads: ggml_bf16_t is reinterpreted as the compiler's __bf16, then a
+// full VLMAX-sized vector is loaded (mirrors the f16 specializations above).
+template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
+ return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
+}
+template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
+ return __riscv_vle16_v_bf16m1(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m1());
+}
+template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
+ return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
+}
+#endif
+
+#if defined(__riscv_v_intrinsic)
+// set_zero<T>(): an all-zero vector of type T, used to initialize the gemm
+// accumulators. The f32 variants are gated on the base vector intrinsics
+// (not Zvfh): the Zvfbfwma (bf16) tinyBLAS path also uses f32 accumulators,
+// and Zvfbfwma does not imply Zvfh. Zero literals are normalized to `0`
+// (one specialization previously used `0.0f`; both splat the same value).
+template <typename T> T set_zero();
+
+template <> inline vfloat32m1_t set_zero() {
+ return __riscv_vfmv_v_f_f32m1(0, __riscv_vsetvlmax_e32m1());
+}
+template <> inline vfloat32m2_t set_zero() {
+ return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2());
+}
+template <> inline vfloat32m4_t set_zero() {
+ return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4());
+}
+template <> inline vfloat32m8_t set_zero() {
+ return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8());
+}
+#endif
+
+#if defined(__riscv_zvfh)
+// f16 zero vectors (the vfloat16* types require Zvfh support).
+template <> inline vfloat16mf2_t set_zero() {
+ return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
+}
+template <> inline vfloat16m1_t set_zero() {
+ return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
+}
+template <> inline vfloat16m2_t set_zero() {
+ return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
+}
+template <> inline vfloat16m4_t set_zero() {
+ return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
+}
+#endif
+
+#if defined(__riscv_v_intrinsic)
+// vlmax<T>(): number of elements a single vector op processes for vector
+// type T (VLMAX for that SEW/LMUL combination).
+// NOTE(review): the `return 0;` fallback for an unsupported T would make
+// matmul()'s `k % vlmax<V>()` divide by zero and the `l += vl` loops never
+// advance — an unsupported T should arguably be a compile error
+// (static_assert) rather than 0; currently unreachable with the
+// instantiations used in this file.
+template <typename T> size_t vlmax() {
+ if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
+ else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
+ else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
+ else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
+ else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
+ else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
+ else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
+ else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
+ return 0;
+}
+#endif
+
////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION
const int64_t ldc;
};
+#if defined(__riscv_v_intrinsic)
+// tinyBLAS_RVV: register-tiled matrix-multiply kernel for the RISC-V Vector
+// extension.
+//   D      accumulator vector type (f32; for widening madd it has twice the
+//          LMUL of V)
+//   V      input vector type loaded from A and B (f32, f16 or bf16)
+//   TA/TB/TC  element types of the A, B and C buffers.
+// Each C element is a k-long dot product of a row of A (stride lda) and a
+// row of B (stride ldb), written to C with stride ldc. Work is distributed
+// over threads through the ggml threadpool's shared chunk counter.
+template <typename D, typename V, typename TA, typename TB, typename TC>
+class tinyBLAS_RVV {
+ public:
+ tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc)
+ : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
+ }
+
+ // Entry point. Returns false when this kernel cannot handle the shape
+ // (k not a multiple of VLMAX, or m not a multiple of 4) so the caller can
+ // fall back to another implementation.
+ bool matmul(int64_t m, int64_t n) {
+ if (k % vlmax<V>() != 0) {
+ return false;
+ }
+
+ // Tile shapes <RM, RN, BM> are tuned per LMUL: larger LMUL leaves fewer
+ // usable vector register groups, so the register tile shrinks (6xRM
+ // accumulators at LMUL 1, down to 2x2 at LMUL 4). BM grows with m so
+ // each job covers a taller strip of rows when there are enough of them
+ // relative to the thread count.
+#if LMUL == 1
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+ mnpack<4, 6, 4>(m, n, SIZE_N, 12);
+ return true;
+ }
+ if (m % 8 == 0 ) {
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+ mnpack<4, 6, 2>(m, n, SIZE_N, 12);
+ return true;
+ }
+ if (m % 4 == 0) {
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+ mnpack<4, 6, 1>(m, n, SIZE_N, 12);
+ return true;
+ }
+#elif LMUL == 2
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+ mnpack<4, 3, 4>(m, n, SIZE_N, 24);
+ return true;
+ }
+ if (m % 8 == 0 ) {
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+ mnpack<4, 3, 2>(m, n, SIZE_N, 24);
+ return true;
+ }
+ if (m % 4 == 0) {
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+ mnpack<4, 3, 1>(m, n, SIZE_N, 24);
+ return true;
+ }
+#else // LMUL = 4
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
+ mnpack<2, 2, 8>(m, n, SIZE_N, 36);
+ return true;
+ }
+ if (m % 8 == 0 ) {
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
+ mnpack<2, 2, 4>(m, n, SIZE_N, 36);
+ return true;
+ }
+ if (m % 4 == 0) {
+ const int64_t SIZE_N = BLOCK_SIZE<2>(n);
+ mnpack<2, 2, 2>(m, n, SIZE_N, 36);
+ return true;
+ }
+#endif
+ return false;
+ }
+
+ private:
+ // Compile-time dispatch: peel RN down recursively until it matches the
+ // runtime-selected column tile width SIZE_N, then run gemm<RM, RN, BM>.
+ template<int RM, int RN, int BM>
+ inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
+ if (SIZE_N == RN) {
+ return gemm<RM, RN, BM>(m, n, BN);
+ }
+ if constexpr (RN > 1) {
+ return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
+ } else {
+ GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
+ GGML_ASSERT(false); // we have miss something.
+ }
+ }
+
+ // 4x6 register tile: 4 rows of A times 6 rows of B held in 24 vector
+ // accumulators; B vectors are loaded once per k-step and reused across
+ // the 4 A rows. Each accumulator is horizontally summed into one C
+ // element at the end. k is a multiple of vl (checked in matmul), so the
+ // loop needs no tail handling.
+ inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+ D Cv12 = set_zero<D>();
+ D Cv13 = set_zero<D>();
+ D Cv20 = set_zero<D>();
+ D Cv21 = set_zero<D>();
+ D Cv22 = set_zero<D>();
+ D Cv23 = set_zero<D>();
+ D Cv30 = set_zero<D>();
+ D Cv31 = set_zero<D>();
+ D Cv32 = set_zero<D>();
+ D Cv33 = set_zero<D>();
+ D Cv40 = set_zero<D>();
+ D Cv41 = set_zero<D>();
+ D Cv42 = set_zero<D>();
+ D Cv43 = set_zero<D>();
+ D Cv50 = set_zero<D>();
+ D Cv51 = set_zero<D>();
+ D Cv52 = set_zero<D>();
+ D Cv53 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
+ V Bv5 = load<V>(B + ldb * (jj + 5) + l);
+
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv20 = madd(Av0, Bv2, Cv20);
+ Cv30 = madd(Av0, Bv3, Cv30);
+ Cv40 = madd(Av0, Bv4, Cv40);
+ Cv50 = madd(Av0, Bv5, Cv50);
+
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ Cv21 = madd(Av1, Bv2, Cv21);
+ Cv31 = madd(Av1, Bv3, Cv31);
+ Cv41 = madd(Av1, Bv4, Cv41);
+ Cv51 = madd(Av1, Bv5, Cv51);
+
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv12 = madd(Av2, Bv1, Cv12);
+ Cv22 = madd(Av2, Bv2, Cv22);
+ Cv32 = madd(Av2, Bv3, Cv32);
+ Cv42 = madd(Av2, Bv4, Cv42);
+ Cv52 = madd(Av2, Bv5, Cv52);
+
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+ Cv03 = madd(Av3, Bv0, Cv03);
+ Cv13 = madd(Av3, Bv1, Cv13);
+ Cv23 = madd(Av3, Bv2, Cv23);
+ Cv33 = madd(Av3, Bv3, Cv33);
+ Cv43 = madd(Av3, Bv4, Cv43);
+ Cv53 = madd(Av3, Bv5, Cv53);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
+ C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
+ C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
+ C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
+ C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
+ }
+
+ // Same pattern as gemm_bloc_4x6, with a 4x5 tile (20 accumulators).
+ inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+ D Cv12 = set_zero<D>();
+ D Cv13 = set_zero<D>();
+ D Cv20 = set_zero<D>();
+ D Cv21 = set_zero<D>();
+ D Cv22 = set_zero<D>();
+ D Cv23 = set_zero<D>();
+ D Cv30 = set_zero<D>();
+ D Cv31 = set_zero<D>();
+ D Cv32 = set_zero<D>();
+ D Cv33 = set_zero<D>();
+ D Cv40 = set_zero<D>();
+ D Cv41 = set_zero<D>();
+ D Cv42 = set_zero<D>();
+ D Cv43 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
+ V Bv4 = load<V>(B + ldb * (jj + 4) + l);
+
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv20 = madd(Av0, Bv2, Cv20);
+ Cv30 = madd(Av0, Bv3, Cv30);
+ Cv40 = madd(Av0, Bv4, Cv40);
+
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ Cv21 = madd(Av1, Bv2, Cv21);
+ Cv31 = madd(Av1, Bv3, Cv31);
+ Cv41 = madd(Av1, Bv4, Cv41);
+
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv12 = madd(Av2, Bv1, Cv12);
+ Cv22 = madd(Av2, Bv2, Cv22);
+ Cv32 = madd(Av2, Bv3, Cv32);
+ Cv42 = madd(Av2, Bv4, Cv42);
+
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+ Cv03 = madd(Av3, Bv0, Cv03);
+ Cv13 = madd(Av3, Bv1, Cv13);
+ Cv23 = madd(Av3, Bv2, Cv23);
+ Cv33 = madd(Av3, Bv3, Cv33);
+ Cv43 = madd(Av3, Bv4, Cv43);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
+ C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
+ C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
+ C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
+ C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
+ }
+
+ // 4x4 tile; note A vectors are loaded once per k-step and reused across
+ // the 4 B columns (loop body transposed relative to gemm_bloc_4x6).
+ inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+ D Cv12 = set_zero<D>();
+ D Cv13 = set_zero<D>();
+ D Cv20 = set_zero<D>();
+ D Cv21 = set_zero<D>();
+ D Cv22 = set_zero<D>();
+ D Cv23 = set_zero<D>();
+ D Cv30 = set_zero<D>();
+ D Cv31 = set_zero<D>();
+ D Cv32 = set_zero<D>();
+ D Cv33 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv03 = madd(Av3, Bv0, Cv03);
+
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ Cv12 = madd(Av2, Bv1, Cv12);
+ Cv13 = madd(Av3, Bv1, Cv13);
+
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
+ Cv20 = madd(Av0, Bv2, Cv20);
+ Cv21 = madd(Av1, Bv2, Cv21);
+ Cv22 = madd(Av2, Bv2, Cv22);
+ Cv23 = madd(Av3, Bv2, Cv23);
+
+ V Bv3 = load<V>(B + ldb * (jj + 3) + l);
+ Cv30 = madd(Av0, Bv3, Cv30);
+ Cv31 = madd(Av1, Bv3, Cv31);
+ Cv32 = madd(Av2, Bv3, Cv32);
+ Cv33 = madd(Av3, Bv3, Cv33);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+ C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
+ C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
+ C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
+ C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
+ }
+
+ // 4x3 tile.
+ inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+ D Cv12 = set_zero<D>();
+ D Cv13 = set_zero<D>();
+ D Cv20 = set_zero<D>();
+ D Cv21 = set_zero<D>();
+ D Cv22 = set_zero<D>();
+ D Cv23 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv03 = madd(Av3, Bv0, Cv03);
+
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ Cv12 = madd(Av2, Bv1, Cv12);
+ Cv13 = madd(Av3, Bv1, Cv13);
+
+ V Bv2 = load<V>(B + ldb * (jj + 2) + l);
+ Cv20 = madd(Av0, Bv2, Cv20);
+ Cv21 = madd(Av1, Bv2, Cv21);
+ Cv22 = madd(Av2, Bv2, Cv22);
+ Cv23 = madd(Av3, Bv2, Cv23);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+ C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+ C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+ C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+ C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+ }
+
+ // 4x2 tile.
+ inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+ D Cv12 = set_zero<D>();
+ D Cv13 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv03 = madd(Av3, Bv0, Cv03);
+
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ Cv12 = madd(Av2, Bv1, Cv12);
+ Cv13 = madd(Av3, Bv1, Cv13);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+ C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+ }
+
+ // 4x1 tile (single B column).
+ inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv02 = set_zero<D>();
+ D Cv03 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+ V Av2 = load<V>(A + lda * (ii + 2) + l);
+ V Av3 = load<V>(A + lda * (ii + 3) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ Cv02 = madd(Av2, Bv0, Cv02);
+ Cv03 = madd(Av3, Bv0, Cv03);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+ C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+ }
+
+ // 2x2 tile (used at LMUL 4, where register groups are scarce).
+ inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+ D Cv10 = set_zero<D>();
+ D Cv11 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+
+ V Bv1 = load<V>(B + ldb * (jj + 1) + l);
+ Cv10 = madd(Av0, Bv1, Cv10);
+ Cv11 = madd(Av1, Bv1, Cv11);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+ C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+ }
+
+ // 2x1 tile.
+ inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
+ size_t vl = vlmax<V>();
+ D Cv00 = set_zero<D>();
+ D Cv01 = set_zero<D>();
+
+ for (int64_t l = 0; l < k; l += vl) {
+ V Av0 = load<V>(A + lda * (ii + 0) + l);
+ V Av1 = load<V>(A + lda * (ii + 1) + l);
+
+ V Bv0 = load<V>(B + ldb * (jj + 0) + l);
+ Cv00 = madd(Av0, Bv0, Cv00);
+ Cv01 = madd(Av1, Bv0, Cv01);
+ }
+
+ C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+ C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+ }
+
+ // Compile-time dispatch from <RM, RN> to the matching hand-unrolled tile
+ // kernel above. Only the <RM, RN> pairs reachable from matmul() are
+ // covered; any other combination compiles to an empty function.
+ template <int RM, int RN>
+ inline void gemm_bloc(int64_t ii, int64_t jj) {
+ if constexpr (RM == 4) {
+ if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
+ if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
+ if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
+ if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
+ if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
+ if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
+ } else if constexpr (RM == 2) {
+ if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
+ if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
+ }
+ }
+
+ // Outer driver: splits the m x n output into jobs of BM row-tiles by one
+ // column bloc, and distributes jobs via the threadpool's atomic chunk
+ // counter. The first jj_RN column tiles are RN wide; the remainder fall
+ // back to RN-1 so any n is covered (BLOC_POS is a helper defined
+ // elsewhere in this file — presumably mapping a bloc index to its start
+ // position given the two tile widths).
+ template <int RM, int RN, int BM>
+ NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
+ GGML_ASSERT(m % (RM * BM) == 0);
+ const int64_t ytiles = m / (RM * BM);
+ const int64_t xtiles = (n + RN -1) / RN;
+ const int64_t jj_RN = (xtiles - (xtiles * RN - n));
+
+ // "round" bloc_size to "nearest" BN
+ const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
+ const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
+ const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
+ const int64_t nb_job = ytiles * NB_BN;
+
+ if (params->ith == 0) {
+ GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+ ggml_threadpool_chunk_set(params->threadpool, params->nth);
+ }
+
+ ggml_barrier(params->threadpool);
+
+ int64_t job = params->ith;
+ while (job < nb_job) {
+ const int64_t ii = (job % ytiles) * RM * BM;
+ const int64_t jb = job / ytiles;
+ const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
+ const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
+
+ // [jj0, jj1) uses full-width RN tiles, [jj1, jj2) the RN-1 tail.
+ const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
+ const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
+ const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
+
+ for (int64_t bi = 0; bi < BM * RM; bi += RM) {
+ int64_t jj = jj0;
+ for (; jj < jj1; jj += RN) {
+ gemm_bloc<RM, RN>(ii + bi, jj);
+ }
+ if constexpr (RN > 1) {
+ for (; jj < jj2; jj += RN - 1) {
+ gemm_bloc<RM, RN-1>(ii + bi, jj);
+ }
+ }
+ GGML_ASSERT(jj == jj2);
+ }
+
+ job = ggml_threadpool_chunk_add(params->threadpool, 1);
+ }
+
+ ggml_barrier(params->threadpool);
+ return;
+ }
+
+ // Problem description (non-owning pointers; lifetimes managed by caller).
+ const ggml_compute_params * params;
+ const TA *const A;
+ const TB *const B;
+ TC *const C;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
+};
+#endif
+
//////////////////////////////////////////////////////////////////////////////////////////
// QUANT ZERO MATRIX MULTIPLICATION
params->ith, params->nth};
tb.matmul(m, n);
return true;
+#elif defined(__riscv_zvfh)
+ #if LMUL == 1
+ tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc};
+ #elif LMUL == 2
+ tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc};
+ #else // LMUL = 4
+ tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc};
+ #endif
+ return tb.matmul(m, n);
#else
return false;
#endif
tb.matmul(m, n);
return true;
}
+#elif defined(__riscv_zvfbfwma)
+ #if LMUL == 1
+ tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
+ k, (const ggml_bf16_t *)A, lda,
+ (const ggml_bf16_t *)B, ldb,
+ (float *)C, ldc};
+ #elif LMUL == 2
+ tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
+ k, (const ggml_bf16_t *)A, lda,
+ (const ggml_bf16_t *)B, ldb,
+ (float *)C, ldc};
+ #else // LMUL = 4
+ tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
+ k, (const ggml_bf16_t *)A, lda,
+ (const ggml_bf16_t *)B, ldb,
+ (float *)C, ldc};
+ #endif
+ return tb.matmul(m, n);
#endif
return false;
}
(float *)C, ldc};
return tb.matmul(m, n);
}
+#elif defined(__riscv_zvfh)
+ if (Btype == GGML_TYPE_F16) {
+ #if LMUL == 1
+ tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+ k, (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc};
+ #elif LMUL == 2
+ tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+ k, (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc};
+ #else // LMUL = 4
+ tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+ k, (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc};
+ #endif
+ return tb.matmul(m, n);
+ }
#endif
return false;
}