ifndef GGML_NO_LLAMAFILE
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
- OBJ_GGML += ggml/src/sgemm.o
+ OBJ_GGML += ggml/src/llamafile/sgemm.o
endif
ifdef GGML_RPC
$(CXX) $(CXXFLAGS) -c $< -o $@
ifndef GGML_NO_LLAMAFILE
-ggml/src/sgemm.o: \
- ggml/src/sgemm.cpp \
- ggml/src/sgemm.h \
+ggml/src/llamafile/sgemm.o: \
+ ggml/src/llamafile/sgemm.cpp \
+ ggml/src/llamafile/sgemm.h \
ggml/include/ggml.h
$(CXX) $(CXXFLAGS) -c $< -o $@
endif # GGML_NO_LLAMAFILE
option(GGML_BLAS "ggml: use BLAS" ${GGML_BLAS_DEFAULT})
set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
"ggml: BLAS library vendor")
-option(GGML_LLAMAFILE "ggml: use ggml SGEMM" OFF)
+option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF)
option(GGML_CUDA "ggml: use CUDA" OFF)
option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
endif()
if (GGML_LLAMAFILE)
- message(STATUS "Using ggml SGEMM")
+ message(STATUS "Using llamafile")
add_compile_definitions(GGML_USE_LLAMAFILE)
- set(GGML_HEADERS_LLAMAFILE sgemm.h)
- set(GGML_SOURCES_LLAMAFILE sgemm.cpp)
+ set(GGML_HEADERS_LLAMAFILE llamafile/sgemm.h)
+ set(GGML_SOURCES_LLAMAFILE llamafile/sgemm.cpp)
endif()
if (GGML_CUDA)
#include "ggml.h"
#include "ggml-aarch64.h"
-
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
#include <alloca.h>
#endif
#ifdef GGML_USE_LLAMAFILE
-#include "sgemm.h"
+#include <llamafile/sgemm.h>
#endif
#if defined(_MSC_VER)
--- /dev/null
+++ b/ggml/src/llamafile/sgemm.cpp
+// Copyright 2024 Mozilla Foundation
+//
+// Permission is hereby granted, free of charge, to any person obtaining
+// a copy of this software and associated documentation files (the
+// "Software"), to deal in the Software without restriction, including
+// without limitation the rights to use, copy, modify, merge, publish,
+// distribute, sublicense, and/or sell copies of the Software, and to
+// permit persons to whom the Software is furnished to do so, subject to
+// the following conditions:
+//
+// The above copyright notice and this permission notice shall be
+// included in all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+//
+// _ _ ___ _ _ ___
+// | |_(_)_ _ _ _| _ ) | /_\ / __|
+// | _| | ' \ || | _ \ |__ / _ \\__ \.
+// \__|_|_||_\_, |___/____/_/ \_\___/
+// |__/
+//
+// BASIC LINEAR ALGEBRA SUBPROGRAMS
+//
+//
+// This file implements multithreaded CPU matrix multiplication for the
+// common contiguous use case C = Aᵀ * B. These kernels are designed to
+// have excellent performance[1] for matrices that fit in the CPU cache
+// without imposing any overhead such as cache filling or malloc calls.
+//
+// This implementation does not guarantee any upper bound with rounding
+// errors, which grow along with k. Our goal's to maximally exploit the
+// hardware for performance, and then use whatever resources remain for
+// improving numerical accuracy.
+//
+// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
+// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
+
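+// As a scalar model, every output element is the dot product of one
+// lda-strided row of A's storage with one ldb-strided row of B's storage,
+// written into column-major C. A minimal reference sketch of that contract:
+//
+//     for (int64_t j = 0; j < n; ++j)
+//         for (int64_t i = 0; i < m; ++i) {
+//             float sum = 0;
+//             for (int64_t l = 0; l < k; ++l)
+//                 sum += A[lda*i + l] * B[ldb*j + l];
+//             C[ldc*j + i] = sum;
+//         }
+//
+// The classes below compute the same thing blockwise with SIMD registers.
+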
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Wpedantic"
+#pragma GCC diagnostic ignored "-Wignored-attributes"
+#endif
+
+#include "sgemm.h"
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+
+#ifdef _MSC_VER
+#define NOINLINE __declspec(noinline)
+#else
+#define NOINLINE __attribute__((__noinline__))
+#endif
+
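+// NEON and AVX512 expose 32 vector registers, while SSE/AVX/AVX2 expose 16.
+// The register-blocked tile sizes selected in mnpack() below are sized
+// against this budget so the accumulator tiles stay resident in registers.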
+#if defined(__ARM_NEON) || defined(__AVX512F__)
+#define VECTOR_REGISTERS 32
+#else
+#define VECTOR_REGISTERS 16
+#endif
+
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+namespace {
+
+inline float unhalf(ggml_fp16_t d) {
+ return GGML_FP16_TO_FP32(d);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED ARITHMETIC OPERATIONS
+
+#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
+inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
+inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
+#endif // __SSE__
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
+inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
+inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
+#endif // __AVX__
+
+#if defined(__AVX512F__)
+inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
+inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
+inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
+#endif // __AVX512F__
+
+#if defined(__ARM_NEON)
+inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
+inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
+inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
+#endif // __ARM_NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
+inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
+inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED FUSED MULTIPLY ADD
+
+/**
+ * Computes a * b + c.
+ */
+template <typename T, typename U>
+inline U madd(T a, T b, U c) {
+ return add(mul(a, b), c);
+}
+
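+// The generic madd above rounds twice (once after the multiply, once after
+// the add); the specializations below use fused multiply-add, which performs
+// a single rounding and is both faster and slightly more accurate on
+// hardware that supports it.
+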
+#if defined(__FMA__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <>
+inline __m256 madd(__m256 a, __m256 b, __m256 c) {
+ return _mm256_fmadd_ps(a, b, c);
+}
+#endif
+#if defined(__AVX512F__)
+template <>
+inline __m512 madd(__m512 a, __m512 b, __m512 c) {
+ return _mm512_fmadd_ps(a, b, c);
+}
+#endif
+#endif
+
+#if defined(__ARM_FEATURE_FMA)
+template <>
+inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+ return vfmaq_f32(c, b, a);
+}
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+template <>
+inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
+ return vfmaq_f16(c, b, a);
+}
+#endif
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED HORIZONTAL SUM
+
+#if defined(__ARM_NEON)
+inline float hsum(float32x4_t x) {
+ return vaddvq_f32(x);
+}
+#endif // __ARM_NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+inline float hsum(float16x8_t x) {
+ return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
+ vcvt_f32_f16(vget_high_f16(x))));
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline float hsum(__m128 x) {
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+ x = _mm_add_ps(x, _mm_movehl_ps(x, x));
+ x = _mm_add_ss(x, _mm_movehdup_ps(x));
+#else
+ __m128 t;
+ t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
+ x = _mm_add_ps(x, t);
+ t = _mm_movehl_ps(t, x);
+ x = _mm_add_ss(x, t);
+#endif
+ return _mm_cvtss_f32(x);
+}
+#endif
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline float hsum(__m256 x) {
+ return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
+ _mm256_castps256_ps128(x)));
+}
+#endif // __AVX__
+
+#if defined(__AVX512F__)
+inline float hsum(__m512 x) {
+ return _mm512_reduce_add_ps(x);
+}
+#endif // __AVX512F__
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED MEMORY LOADING
+
+template <typename T, typename U> T load(const U *);
+
+#if defined(__ARM_NEON)
+template <> inline float32x4_t load(const float *p) {
+ return vld1q_f32(p);
+}
+#if !defined(_MSC_VER)
+template <> inline float16x8_t load(const ggml_fp16_t *p) {
+ return vld1q_f16((const float16_t *)p);
+}
+template <> inline float32x4_t load(const ggml_fp16_t *p) {
+ return vcvt_f32_f16(vld1_f16((const float16_t *)p));
+}
+#endif // _MSC_VER
+#endif // __ARM_NEON
+
+#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <> inline __m128 load(const float *p) {
+ return _mm_loadu_ps(p);
+}
+#endif // __SSE__
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <> inline __m256 load(const float *p) {
+ return _mm256_loadu_ps(p);
+}
+#endif // __AVX__
+
+#if defined(__F16C__)
+template <> inline __m256 load(const ggml_fp16_t *p) {
+ return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
+}
+#endif // __F16C__
+
+#if defined(__AVX512F__)
+template <> inline __m512 load(const float *p) {
+ return _mm512_loadu_ps(p);
+}
+template <> inline __m512 load(const ggml_fp16_t *p) {
+ return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
+}
+#endif // __AVX512F__
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// FLOATING POINT MATRIX MULTIPLICATION
+
+template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
+class tinyBLAS {
+ public:
+ tinyBLAS(int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc,
+ int ith, int nth)
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+ }
+
+ void matmul(int64_t m, int64_t n) {
+ mnpack(0, m, 0, n);
+ }
+
+ private:
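+ // mnpack() dispatches the largest register-blocked kernel available for the
+ // remaining (m - m0) x (n - n0) region. The switch key packs min(rows, 5)
+ // into the high nibble and min(cols, 5) into the low nibble; for example a
+ // remainder of 7 rows by 6 columns yields 0x55, which selects gemm<5, 5>
+ // when 32 vector registers are available. The two recursive calls afterwards
+ // cover the row and column fringes the chosen tile does not divide evenly.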
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
+#if VECTOR_REGISTERS == 32
+ case 0x55:
+ mc = 5;
+ nc = 5;
+ gemm<5, 5>(m0, m, n0, n);
+ break;
+ case 0x45:
+ mc = 4;
+ nc = 5;
+ gemm<4, 5>(m0, m, n0, n);
+ break;
+ case 0x54:
+ mc = 5;
+ nc = 4;
+ gemm<5, 4>(m0, m, n0, n);
+ break;
+ case 0x44:
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
+ break;
+ case 0x53:
+ mc = 5;
+ nc = 3;
+ gemm<5, 3>(m0, m, n0, n);
+ break;
+ case 0x35:
+ mc = 3;
+ nc = 5;
+ gemm<3, 5>(m0, m, n0, n);
+ break;
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+#else
+ case 0x55:
+ case 0x54:
+ case 0x53:
+ case 0x45:
+ case 0x44:
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+ case 0x35:
+#endif
+ case 0x34:
+ mc = 3;
+ nc = 4;
+ gemm<3, 4>(m0, m, n0, n);
+ break;
+ case 0x52:
+ mc = 5;
+ nc = 2;
+ gemm<5, 2>(m0, m, n0, n);
+ break;
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x25:
+ mc = 2;
+ nc = 5;
+ gemm<2, 5>(m0, m, n0, n);
+ break;
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x51:
+ mc = 5;
+ nc = 1;
+ gemm<5, 1>(m0, m, n0, n);
+ break;
+ case 0x41:
+ mc = 4;
+ nc = 1;
+ gemm<4, 1>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x15:
+ mc = 1;
+ nc = 5;
+ gemm<1, 5>(m0, m, n0, n);
+ break;
+ case 0x14:
+ mc = 1;
+ nc = 4;
+ gemm<1, 4>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
+ mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
+ nc = 1;
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
+ }
+ mp = m0 + (m - m0) / mc * mc;
+ np = n0 + (n - n0) / nc * nc;
+ mnpack(mp, m, n0, np);
+ mnpack(m0, m, np, n);
+ }
+
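+ // gemm<RM, RN>() walks the output in RM x RN register tiles. Work is split
+ // statically: with `tiles` jobs in total, each thread takes a contiguous
+ // slice of ceil(tiles / nth) of them, so thread `ith` handles the range
+ // [ith * duty, min((ith + 1) * duty, tiles)). Each job accumulates its tile
+ // of dot products in registers, KN elements of k at a time, and finally
+ // horizontally sums each accumulator into C.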
+ template <int RM, int RN>
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
+ D Cv[RN][RM] = {};
+ for (int64_t l = 0; l < k; l += KN)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
+ load<V>(B + ldb * (jj + j) + l),
+ Cv[j][i]);
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+ }
+ }
+
+ const TA *const A;
+ const TB *const B;
+ TC *const C;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
+ const int ith;
+ const int nth;
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// QUANT ZERO MATRIX MULTIPLICATION
+
+#if defined(__ARM_FEATURE_DOTPROD)
+template <typename TA>
+class tinyBLAS_Q0_ARM {
+ public:
+ tinyBLAS_Q0_ARM(int64_t k,
+ const TA *A, int64_t lda,
+ const block_q8_0 *B, int64_t ldb,
+ float *C, int64_t ldc,
+ int ith, int nth)
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+ }
+
+ void matmul(int64_t m, int64_t n) {
+ mnpack(0, m, 0, n);
+ }
+
+ private:
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
+ mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
+ nc = 1;
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
+ }
+ mp = m0 + (m - m0) / mc * mc;
+ np = n0 + (n - n0) / nc * nc;
+ mnpack(mp, m, n0, np);
+ mnpack(m0, m, np, n);
+ }
+
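+ // Each quantized block holds 32 values. For every (row, column, block)
+ // triple, the two chained vdotq_s32 calls below accumulate the 32-wide int8
+ // dot product of one block of A with one block of B across four int32
+ // lanes; vmlaq_n_f32 scales the converted lanes by the product of the two
+ // per-block fp16 scales (d), and hsum() reduces the lanes at the end.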
+ template <int RM, int RN>
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
+ float32x4_t Cv[RN][RM] = {};
+ for (int64_t l = 0; l < k; ++l)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ Cv[j][i] = vmlaq_n_f32(Cv[j][i],
+ vcvtq_f32_s32(vdotq_s32(
+ vdotq_s32(vdupq_n_s32(0),
+ load_lo(A + lda * (ii + i) + l),
+ load_lo(B + ldb * (jj + j) + l)),
+ load_hi(A + lda * (ii + i) + l),
+ load_hi(B + ldb * (jj + j) + l))),
+ unhalf(A[lda * (ii + i) + l].d) *
+ unhalf(B[ldb * (jj + j) + l].d));
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+ }
+ }
+
+ inline int8x16_t load_lo(const block_q8_0 *b) {
+ return vld1q_s8(b->qs);
+ }
+
+ inline int8x16_t load_hi(const block_q8_0 *b) {
+ return vld1q_s8(b->qs + 16);
+ }
+
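+ // Q4_0 packs two 4-bit values per byte with a bias of 8. The loaders below
+ // expand the low and high nibbles into signed int8 lanes in [-8, 7] so that
+ // Q4_0 rows reuse the same dot-product path as Q8_0.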
+ inline int8x16_t load_lo(const block_q4_0 *b) {
+ return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
+ vdupq_n_u8(0x0f))),
+ vdupq_n_s8(0x8));
+ }
+
+ inline int8x16_t load_hi(const block_q4_0 *b) {
+ return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
+ vdupq_n_s8(0x8));
+ }
+
+ const TA *const A;
+ const block_q8_0 *const B;
+ float *const C;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
+ const int ith;
+ const int nth;
+};
+#endif // __ARM_FEATURE_DOTPROD
+
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_Q0_AVX {
+ public:
+ tinyBLAS_Q0_AVX(int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc,
+ int ith, int nth)
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+ }
+
+ void matmul(int64_t m, int64_t n) {
+ mnpack(0, m, 0, n);
+ }
+
+ private:
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
+#if VECTOR_REGISTERS == 32
+ case 0x44:
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
+ break;
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+ case 0x34:
+ mc = 3;
+ nc = 4;
+ gemm<3, 4>(m0, m, n0, n);
+ break;
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+#else
+ case 0x44:
+ case 0x43:
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x34:
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+ case 0x33:
+#endif
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x41:
+ mc = 4;
+ nc = 1;
+ gemm<4, 1>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x14:
+ mc = 1;
+ nc = 4;
+ gemm<1, 4>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
+ mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
+ nc = 1;
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
+ }
+ mp = m0 + (m - m0) / mc * mc;
+ np = n0 + (n - n0) / nc * nc;
+ mnpack(mp, m, n0, np);
+ mnpack(m0, m, np, n);
+ }
+
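+ // The x86 u8*s8 dot instructions (maddubs / dpbusd) require an unsigned
+ // left operand, so the kernels use a sign trick: _mm256_sign_epi8(a, a)
+ // yields |a| and _mm256_sign_epi8(b, a) applies a's signs to b, and
+ // dot(|a|, sign(b, a)) == dot(a, b). The pre-AVX2 path mirrors the same
+ // idea on 128-bit halves. Each block's integer dot product is scaled by
+ // the product of the two per-block fp16 scales (d), as on NEON.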
+ template <int RM, int RN>
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
+ __m256 Cv[RN][RM] = {};
+ for (int64_t l = 0; l < k; ++l)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i) {
+#if defined(__AVX2__)
+ __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+ load(A + lda * (ii + i) + l)),
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+ load(A + lda * (ii + i) + l)));
+#else
+ __m128i ali0 = load0(A + lda * (ii + i) + l);
+ __m128i ali1 = load1(A + lda * (ii + i) + l);
+ __m128i blj0 = load0(B + ldb * (jj + j) + l);
+ __m128i blj1 = load1(B + ldb * (jj + j) + l);
+
+ __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
+ __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
+ __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
+ __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
+
+ // updot
+ const __m128i oneFill = _mm_set1_epi16(1);
+ __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
+ __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
+ __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
+#endif
+ Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
+ unhalf(B[ldb * (jj + j) + l].d)),
+ udTmp,
+ Cv[j][i]);
+ }
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+ }
+ }
+
+ inline __m256i load(const block_q8_0 *b) {
+ return _mm256_loadu_si256((const __m256i *)b->qs);
+ }
+
+ inline __m128i load0(const block_q8_0 *b) {
+ return _mm_loadu_si128((const __m128i *)b->qs);
+ }
+
+ inline __m128i load1(const block_q8_0 *b) {
+ return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
+ }
+
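+ // As in the NEON kernel, Q4_0 nibbles are stored with a bias of 8; the
+ // loaders below mask or shift them out and subtract 8 to obtain signed
+ // int8 lanes in [-8, 7] that feed the same integer dot-product path as
+ // Q8_0.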
+ inline __m256i load(const block_q4_0 *b) {
+ return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
+ }
+
+ inline __m128i load0(const block_q4_0 *b) {
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
+ }
+
+ inline __m128i load1(const block_q4_0 *b) {
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
+ }
+
+ inline __m256 updot(__m256i u, __m256i s) {
+ __m256i res;
+#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+ res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
+#else
+ res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
+#endif
+ return _mm256_cvtepi32_ps(res);
+ }
+
+ static inline __m256i denibble(const uint8_t *p) {
+ __m128i x = _mm_loadu_si128((const __m128i *)p);
+ return _mm256_and_si256(_mm256_set1_epi8(15),
+ _mm256_insertf128_si256(_mm256_castsi128_si256(x),
+ _mm_srli_epi16(x, 4), 1));
+ }
+
+ const TA *const A;
+ const TB *const B;
+ TC *const C;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
+ const int ith;
+ const int nth;
+};
+#endif // __AVX__
+
+} // namespace
+
+/**
+ * Performs optimized matrix multiplication on CPU.
+ *
+ * This subroutine may compute C = Aᵀ * B with column major ordering.
+ * Despite its name, this isn't a generalized implementation. Work is
+ * only performed when a handwritten kernel is written and available.
+ * Otherwise the caller should fall back to a general matmul routine.
+ *
+ * For example, for single-threaded single-precision GEMM you can say
+ *
+ * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
+ * 0, 1,
+ * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
+ *
+ * @param m is rows in `A` and `C`
+ * @param n is cols in `B` and `C`
+ * @param k is cols in `A` and rows in `B`
+ * @param A is first input matrix (always transposed)
+ * @param lda is row stride of `A`
+ * @param B is second input matrix (never transposed)
+ * @param ldb is row stride of `B`
+ * @param C is input/output array of output matrices
+ * @param ldc is row stride of `C`
+ * @param ith is thread id (must be less than `nth`)
+ * @param nth is number of threads (must be greater than zero)
+ * @param Atype is GGML data type of `A`
+ * @param Btype is GGML data type of `B`
+ * @param Ctype is GGML data type of `C`
+ * @return true if this function was able to service the matmul request
+ */
+bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+ int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
+
+ assert(m >= 0);
+ assert(n >= 0);
+ assert(k >= 0);
+ assert(lda >= k);
+ assert(ldb >= k);
+ assert(ldc >= m);
+ assert(nth > 0);
+ assert(ith < nth);
+
+ if (Ctype != GGML_TYPE_F32)
+ return false;
+
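+ // Kernel selection depends on the element types and the compiled ISA. The
+ // float kernels require k to be a multiple of the vector width KN (16 for
+ // AVX512, 8 for AVX/F16C, 4 for NEON); for the quantized kernels k counts
+ // 32-element blocks. Any unsupported combination returns false so the
+ // caller can fall back to a general matmul routine.
+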
+ switch (Atype) {
+
+ case GGML_TYPE_F32: {
+ if (Btype != GGML_TYPE_F32)
+ return false;
+#if defined(__AVX512F__)
+ if (k % 16)
+ return false;
+ tinyBLAS<16, __m512, __m512, float, float, float> tb{
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__AVX__) || defined(__AVX2__)
+ if (k % 8)
+ return false;
+ tinyBLAS<8, __m256, __m256, float, float, float> tb{
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__ARM_NEON)
+ if (n < 4)
+ return false;
+ if (k % 4)
+ return false;
+ tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#else
+ return false;
+#endif
+ }
+
+ case GGML_TYPE_F16: {
+#if defined(__AVX512F__)
+ if (k % 16)
+ return false;
+ if (Btype != GGML_TYPE_F32)
+ return false;
+ tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
+ k, (const ggml_fp16_t *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
+ if (k % 8)
+ return false;
+ if (Btype != GGML_TYPE_F32)
+ return false;
+ tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
+ k, (const ggml_fp16_t *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+ if (n < 8)
+ return false;
+ if (k % 8)
+ return false;
+ if (Btype != GGML_TYPE_F16)
+ return false;
+ tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
+ k, (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__ARM_NEON) && !defined(_MSC_VER)
+ if (k % 4)
+ return false;
+ if (Btype != GGML_TYPE_F32)
+ return false;
+ tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
+ k, (const ggml_fp16_t *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#else
+ return false;
+#endif
+ }
+
+ case GGML_TYPE_Q8_0: {
+ if (Btype != GGML_TYPE_Q8_0)
+ return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+ tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
+ k, (const block_q8_0 *)A, lda,
+ (const block_q8_0 *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__ARM_FEATURE_DOTPROD)
+ tinyBLAS_Q0_ARM<block_q8_0> tb{
+ k, (const block_q8_0 *)A, lda,
+ (const block_q8_0 *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#else
+ return false;
+#endif
+ }
+
+ case GGML_TYPE_Q4_0: {
+ if (Btype != GGML_TYPE_Q8_0)
+ return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+ tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
+ k, (const block_q4_0 *)A, lda,
+ (const block_q8_0 *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__ARM_FEATURE_DOTPROD)
+ tinyBLAS_Q0_ARM<block_q4_0> tb{
+ k, (const block_q4_0 *)A, lda,
+ (const block_q8_0 *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#else
+ return false;
+#endif
+ }
+
+ default:
+ return false;
+ }
+
+ (void)m;
+ (void)n;
+ (void)k;
+ (void)A;
+ (void)lda;
+ (void)B;
+ (void)ldb;
+ (void)C;
+ (void)ldc;
+ (void)ith;
+ (void)nth;
+ (void)Atype;
+ (void)Btype;
+ (void)Ctype;
+}
--- /dev/null
+++ b/ggml/src/llamafile/sgemm.h
+#pragma once
+#include <stdint.h>
+#include <stdbool.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
+
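+// Computes C = Aᵀ * B on the CPU when a handwritten kernel exists for the
+// given (Atype, Btype, Ctype) combination and the compiled ISA; returns
+// false otherwise so the caller can fall back to a generic matmul. See the
+// documentation block above llamafile_sgemm() in sgemm.cpp for the meaning
+// of each parameter (m, n, k, A, lda, B, ldb, C, ldc, ith, nth, and the
+// three GGML type ids).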
+bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
+ const void *, int64_t, void *, int64_t, int, int,
+ int, int, int);
+
+#ifdef __cplusplus
+}
+#endif
--- a/ggml/src/sgemm.cpp
+++ /dev/null
-// Copyright 2024 Mozilla Foundation
-//
-// Permission is hereby granted, free of charge, to any person obtaining
-// a copy of this software and associated documentation files (the
-// "Software"), to deal in the Software without restriction, including
-// without limitation the rights to use, copy, modify, merge, publish,
-// distribute, sublicense, and/or sell copies of the Software, and to
-// permit persons to whom the Software is furnished to do so, subject to
-// the following conditions:
-//
-// The above copyright notice and this permission notice shall be
-// included in all copies or substantial portions of the Software.
-//
-// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
-// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
-// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
-// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-// SOFTWARE.
-
-//
-// _ _ ___ _ _ ___
-// | |_(_)_ _ _ _| _ ) | /_\ / __|
-// | _| | ' \ || | _ \ |__ / _ \\__ \.
-// \__|_|_||_\_, |___/____/_/ \_\___/
-// |__/
-//
-// BASIC LINEAR ALGEBRA SUBPROGRAMS
-//
-//
-// This file implements multithreaded CPU matrix multiplication for the
-// common contiguous use case C = Aᵀ * B. These kernels are designed to
-// have excellent performance[1] for matrices that fit in the CPU cache
-// without imposing any overhead such as cache filling or malloc calls.
-//
-// This implementation does not guarantee any upper bound with rounding
-// errors, which grow along with k. Our goal's to maximally exploit the
-// hardware for performance, and then use whatever resources remain for
-// improving numerical accuracy.
-//
-// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
-// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic ignored "-Wpedantic"
-#pragma GCC diagnostic ignored "-Wignored-attributes"
-#endif
-
-#include "sgemm.h"
-#include "ggml-impl.h"
-#include "ggml-quants.h"
-
-#ifdef _MSC_VER
-#define NOINLINE __declspec(noinline)
-#else
-#define NOINLINE __attribute__((__noinline__))
-#endif
-
-#if defined(__ARM_NEON) || defined(__AVX512F__)
-#define VECTOR_REGISTERS 32
-#else
-#define VECTOR_REGISTERS 16
-#endif
-
-#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
-namespace {
-
-inline float unhalf(ggml_fp16_t d) {
- return GGML_FP16_TO_FP32(d);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED ARITHMETIC OPERATIONS
-
-#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
-inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
-inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
-#endif // __SSE__
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
-inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
-inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
-#endif // __AVX__
-
-#if defined(__AVX512F__)
-inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
-inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
-inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
-#endif // __AVX512F__
-
-#if defined(__ARM_NEON)
-inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
-inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
-inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
-#endif // __ARM_NEON
-
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
-inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
-inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
-inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED FUSED MULTIPLY ADD
-
-/**
- * Computes a * b + c.
- */
-template <typename T, typename U>
-inline U madd(T a, T b, U c) {
- return add(mul(a, b), c);
-}
-
-#if defined(__FMA__)
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-template <>
-inline __m256 madd(__m256 a, __m256 b, __m256 c) {
- return _mm256_fmadd_ps(a, b, c);
-}
-#endif
-#if defined(__AVX512F__)
-template <>
-inline __m512 madd(__m512 a, __m512 b, __m512 c) {
- return _mm512_fmadd_ps(a, b, c);
-}
-#endif
-#endif
-
-#if defined(__ARM_FEATURE_FMA)
-template <>
-inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
- return vfmaq_f32(c, b, a);
-}
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
-template <>
-inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
- return vfmaq_f16(c, b, a);
-}
-#endif
-#endif
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED HORIZONTAL SUM
-
-#if defined(__ARM_NEON)
-inline float hsum(float32x4_t x) {
- return vaddvq_f32(x);
-}
-#endif // __ARM_NEON
-
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
-inline float hsum(float16x8_t x) {
- return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
- vcvt_f32_f16(vget_high_f16(x))));
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline float hsum(__m128 x) {
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
- x = _mm_add_ps(x, _mm_movehl_ps(x, x));
- x = _mm_add_ss(x, _mm_movehdup_ps(x));
-#else
- __m128 t;
- t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
- x = _mm_add_ps(x, t);
- t = _mm_movehl_ps(t, x);
- x = _mm_add_ss(x, t);
-#endif
- return _mm_cvtss_f32(x);
-}
-#endif
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline float hsum(__m256 x) {
- return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
- _mm256_castps256_ps128(x)));
-}
-#endif // __AVX__
-
-#if defined(__AVX512F__)
-inline float hsum(__m512 x) {
- return _mm512_reduce_add_ps(x);
-}
-#endif // __AVX512F__
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED MEMORY LOADING
-
-template <typename T, typename U> T load(const U *);
-
-#if defined(__ARM_NEON)
-template <> inline float32x4_t load(const float *p) {
- return vld1q_f32(p);
-}
-#if !defined(_MSC_VER)
-template <> inline float16x8_t load(const ggml_fp16_t *p) {
- return vld1q_f16((const float16_t *)p);
-}
-template <> inline float32x4_t load(const ggml_fp16_t *p) {
- return vcvt_f32_f16(vld1_f16((const float16_t *)p));
-}
-#endif // _MSC_VER
-#endif // __ARM_NEON
-
-#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-template <> inline __m128 load(const float *p) {
- return _mm_loadu_ps(p);
-}
-#endif // __SSE__
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-template <> inline __m256 load(const float *p) {
- return _mm256_loadu_ps(p);
-}
-#endif // __AVX__
-
-#if defined(__F16C__)
-template <> inline __m256 load(const ggml_fp16_t *p) {
- return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
-}
-#endif // __F16C__
-
-#if defined(__AVX512F__)
-template <> inline __m512 load(const float *p) {
- return _mm512_loadu_ps(p);
-}
-template <> inline __m512 load(const ggml_fp16_t *p) {
- return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
-}
-#endif // __AVX512F__
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// FLOATING POINT MATRIX MULTIPLICATION
-
-template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
-class tinyBLAS {
- public:
- tinyBLAS(int64_t k,
- const TA *A, int64_t lda,
- const TB *B, int64_t ldb,
- TC *C, int64_t ldc,
- int ith, int nth)
- : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
- }
-
- void matmul(int64_t m, int64_t n) {
- mnpack(0, m, 0, n);
- }
-
- private:
- NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
-#if VECTOR_REGISTERS == 32
- case 0x55:
- mc = 5;
- nc = 5;
- gemm<5, 5>(m0, m, n0, n);
- break;
- case 0x45:
- mc = 4;
- nc = 5;
- gemm<4, 5>(m0, m, n0, n);
- break;
- case 0x54:
- mc = 5;
- nc = 4;
- gemm<5, 4>(m0, m, n0, n);
- break;
- case 0x44:
- mc = 4;
- nc = 4;
- gemm<4, 4>(m0, m, n0, n);
- break;
- case 0x53:
- mc = 5;
- nc = 3;
- gemm<5, 3>(m0, m, n0, n);
- break;
- case 0x35:
- mc = 3;
- nc = 5;
- gemm<3, 5>(m0, m, n0, n);
- break;
- case 0x43:
- mc = 4;
- nc = 3;
- gemm<4, 3>(m0, m, n0, n);
- break;
-#else
- case 0x55:
- case 0x54:
- case 0x53:
- case 0x45:
- case 0x44:
- case 0x43:
- mc = 4;
- nc = 3;
- gemm<4, 3>(m0, m, n0, n);
- break;
- case 0x35:
-#endif
- case 0x34:
- mc = 3;
- nc = 4;
- gemm<3, 4>(m0, m, n0, n);
- break;
- case 0x52:
- mc = 5;
- nc = 2;
- gemm<5, 2>(m0, m, n0, n);
- break;
- case 0x33:
- mc = 3;
- nc = 3;
- gemm<3, 3>(m0, m, n0, n);
- break;
- case 0x25:
- mc = 2;
- nc = 5;
- gemm<2, 5>(m0, m, n0, n);
- break;
- case 0x42:
- mc = 4;
- nc = 2;
- gemm<4, 2>(m0, m, n0, n);
- break;
- case 0x24:
- mc = 2;
- nc = 4;
- gemm<2, 4>(m0, m, n0, n);
- break;
- case 0x32:
- mc = 3;
- nc = 2;
- gemm<3, 2>(m0, m, n0, n);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm<2, 3>(m0, m, n0, n);
- break;
- case 0x51:
- mc = 5;
- nc = 1;
- gemm<5, 1>(m0, m, n0, n);
- break;
- case 0x41:
- mc = 4;
- nc = 1;
- gemm<4, 1>(m0, m, n0, n);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm<2, 2>(m0, m, n0, n);
- break;
- case 0x15:
- mc = 1;
- nc = 5;
- gemm<1, 5>(m0, m, n0, n);
- break;
- case 0x14:
- mc = 1;
- nc = 4;
- gemm<1, 4>(m0, m, n0, n);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm<3, 1>(m0, m, n0, n);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm<1, 3>(m0, m, n0, n);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm<2, 1>(m0, m, n0, n);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm<1, 2>(m0, m, n0, n);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm<1, 1>(m0, m, n0, n);
- break;
- default:
- return;
- }
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
- mnpack(mp, m, n0, np);
- mnpack(m0, m, np, n);
- }
-
- template <int RM, int RN>
- NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t ytiles = (m - m0) / RM;
- int64_t xtiles = (n - n0) / RN;
- int64_t tiles = xtiles * ytiles;
- int64_t duty = (tiles + nth - 1) / nth;
- int64_t start = duty * ith;
- int64_t end = start + duty;
- if (end > tiles)
- end = tiles;
- for (int64_t job = start; job < end; ++job) {
- int64_t ii = m0 + job / xtiles * RM;
- int64_t jj = n0 + job % xtiles * RN;
- D Cv[RN][RM] = {};
- for (int64_t l = 0; l < k; l += KN)
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
- load<V>(B + ldb * (jj + j) + l),
- Cv[j][i]);
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
- }
- }
-
- const TA *const A;
- const TB *const B;
- TC *const C;
- const int64_t k;
- const int64_t lda;
- const int64_t ldb;
- const int64_t ldc;
- const int ith;
- const int nth;
-};
-
-//////////////////////////////////////////////////////////////////////////////////////////
-// QUANT ZERO MATRIX MULTIPLICATION
-
-#if defined(__ARM_FEATURE_DOTPROD)
-template <typename TA>
-class tinyBLAS_Q0_ARM {
- public:
- tinyBLAS_Q0_ARM(int64_t k,
- const TA *A, int64_t lda,
- const block_q8_0 *B, int64_t ldb,
- float *C, int64_t ldc,
- int ith, int nth)
- : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
- }
-
- void matmul(int64_t m, int64_t n) {
- mnpack(0, m, 0, n);
- }
-
- private:
- NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
- case 0x33:
- mc = 3;
- nc = 3;
- gemm<3, 3>(m0, m, n0, n);
- break;
- case 0x32:
- mc = 3;
- nc = 2;
- gemm<3, 2>(m0, m, n0, n);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm<2, 3>(m0, m, n0, n);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm<2, 2>(m0, m, n0, n);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm<3, 1>(m0, m, n0, n);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm<1, 3>(m0, m, n0, n);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm<2, 1>(m0, m, n0, n);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm<1, 2>(m0, m, n0, n);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm<1, 1>(m0, m, n0, n);
- break;
- default:
- return;
- }
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
- mnpack(mp, m, n0, np);
- mnpack(m0, m, np, n);
- }
-
- template <int RM, int RN>
- NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t ytiles = (m - m0) / RM;
- int64_t xtiles = (n - n0) / RN;
- int64_t tiles = xtiles * ytiles;
- int64_t duty = (tiles + nth - 1) / nth;
- int64_t start = duty * ith;
- int64_t end = start + duty;
- if (end > tiles)
- end = tiles;
- for (int64_t job = start; job < end; ++job) {
- int64_t ii = m0 + job / xtiles * RM;
- int64_t jj = n0 + job % xtiles * RN;
- float32x4_t Cv[RN][RM] = {};
- for (int64_t l = 0; l < k; ++l)
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- Cv[j][i] = vmlaq_n_f32(Cv[j][i],
- vcvtq_f32_s32(vdotq_s32(
- vdotq_s32(vdupq_n_s32(0),
- load_lo(A + lda * (ii + i) + l),
- load_lo(B + ldb * (jj + j) + l)),
- load_hi(A + lda * (ii + i) + l),
- load_hi(B + ldb * (jj + j) + l))),
- unhalf(A[lda * (ii + i) + l].d) *
- unhalf(B[ldb * (jj + j) + l].d));
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
- }
- }
-
- inline int8x16_t load_lo(const block_q8_0 *b) {
- return vld1q_s8(b->qs);
- }
-
- inline int8x16_t load_hi(const block_q8_0 *b) {
- return vld1q_s8(b->qs + 16);
- }
-
- inline int8x16_t load_lo(const block_q4_0 *b) {
- return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
- vdupq_n_u8(0x0f))),
- vdupq_n_s8(0x8));
- }
-
- inline int8x16_t load_hi(const block_q4_0 *b) {
- return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
- vdupq_n_s8(0x8));
- }
-
- const TA *const A;
- const block_q8_0 *const B;
- float *const C;
- const int64_t k;
- const int64_t lda;
- const int64_t ldb;
- const int64_t ldc;
- const int ith;
- const int nth;
-};
-#endif // __ARM_FEATURE_DOTPROD
-
-#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
-template <typename TA, typename TB, typename TC>
-class tinyBLAS_Q0_AVX {
- public:
- tinyBLAS_Q0_AVX(int64_t k,
- const TA *A, int64_t lda,
- const TB *B, int64_t ldb,
- TC *C, int64_t ldc,
- int ith, int nth)
- : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
- }
-
- void matmul(int64_t m, int64_t n) {
- mnpack(0, m, 0, n);
- }
-
- private:
- void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
-#if VECTOR_REGISTERS == 32
- case 0x44:
- mc = 4;
- nc = 4;
- gemm<4, 4>(m0, m, n0, n);
- break;
- case 0x43:
- mc = 4;
- nc = 3;
- gemm<4, 3>(m0, m, n0, n);
- break;
- case 0x34:
- mc = 3;
- nc = 4;
- gemm<3, 4>(m0, m, n0, n);
- break;
- case 0x33:
- mc = 3;
- nc = 3;
- gemm<3, 3>(m0, m, n0, n);
- break;
- case 0x42:
- mc = 4;
- nc = 2;
- gemm<4, 2>(m0, m, n0, n);
- break;
- case 0x24:
- mc = 2;
- nc = 4;
- gemm<2, 4>(m0, m, n0, n);
- break;
-#else
- case 0x44:
- case 0x43:
- case 0x42:
- mc = 4;
- nc = 2;
- gemm<4, 2>(m0, m, n0, n);
- break;
- case 0x34:
- case 0x24:
- mc = 2;
- nc = 4;
- gemm<2, 4>(m0, m, n0, n);
- break;
- case 0x33:
-#endif
- case 0x32:
- mc = 3;
- nc = 2;
- gemm<3, 2>(m0, m, n0, n);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm<2, 3>(m0, m, n0, n);
- break;
- case 0x41:
- mc = 4;
- nc = 1;
- gemm<4, 1>(m0, m, n0, n);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm<2, 2>(m0, m, n0, n);
- break;
- case 0x14:
- mc = 1;
- nc = 4;
- gemm<1, 4>(m0, m, n0, n);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm<3, 1>(m0, m, n0, n);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm<1, 3>(m0, m, n0, n);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm<2, 1>(m0, m, n0, n);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm<1, 2>(m0, m, n0, n);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm<1, 1>(m0, m, n0, n);
- break;
- default:
- return;
- }
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
- mnpack(mp, m, n0, np);
- mnpack(m0, m, np, n);
- }
-
- template <int RM, int RN>
- NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t ytiles = (m - m0) / RM;
- int64_t xtiles = (n - n0) / RN;
- int64_t tiles = xtiles * ytiles;
- int64_t duty = (tiles + nth - 1) / nth;
- int64_t start = duty * ith;
- int64_t end = start + duty;
- if (end > tiles)
- end = tiles;
- for (int64_t job = start; job < end; ++job) {
- int64_t ii = m0 + job / xtiles * RM;
- int64_t jj = n0 + job % xtiles * RN;
- __m256 Cv[RN][RM] = {};
- for (int64_t l = 0; l < k; ++l)
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i) {
-#if defined(__AVX2__)
- __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
- load(A + lda * (ii + i) + l)),
- _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
- load(A + lda * (ii + i) + l)));
-#else
- __m128i ali0 = load0(A + lda * (ii + i) + l);
- __m128i ali1 = load1(A + lda * (ii + i) + l);
- __m128i blj0 = load0(B + ldb * (jj + j) + l);
- __m128i blj1 = load1(B + ldb * (jj + j) + l);
-
- __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
- __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
- __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
- __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
-
- // updot
- const __m128i oneFill = _mm_set1_epi16(1);
- __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
- __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
- __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
-#endif
- Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
- unhalf(B[ldb * (jj + j) + l].d)),
- udTmp,
- Cv[j][i]);
- }
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
- }
- }
-
- inline __m256i load(const block_q8_0 *b) {
- return _mm256_loadu_si256((const __m256i *)b->qs);
- }
-
- inline __m128i load0(const block_q8_0 *b) {
- return _mm_loadu_si128((const __m128i *)b->qs);
- }
-
- inline __m128i load1(const block_q8_0 *b) {
- return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
- }
-
- inline __m256i load(const block_q4_0 *b) {
- return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
- }
-
- inline __m128i load0(const block_q4_0 *b) {
- const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
- return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
- }
-
- inline __m128i load1(const block_q4_0 *b) {
- const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
- return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
- }
-
- inline __m256 updot(__m256i u, __m256i s) {
- __m256i res;
-#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
- res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
-#else
- res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
-#endif
- return _mm256_cvtepi32_ps(res);
- }
-
- static inline __m256i denibble(const uint8_t *p) {
- __m128i x = _mm_loadu_si128((const __m128i *)p);
- return _mm256_and_si256(_mm256_set1_epi8(15),
- _mm256_insertf128_si256(_mm256_castsi128_si256(x),
- _mm_srli_epi16(x, 4), 1));
- }
-
- const TA *const A;
- const TB *const B;
- TC *const C;
- const int64_t k;
- const int64_t lda;
- const int64_t ldb;
- const int64_t ldc;
- const int ith;
- const int nth;
-};
-#endif // __AVX__
-
-} // namespace
-
-/**
- * Performs optimized matrix multiplication on CPU.
- *
- * This subroutine may compute C = Aᵀ * B with column major ordering.
- * Despite its name, this isn't a generalized implementation. Work is
- * only performed when a handwritten kernel is written and available.
- * Otherwise the caller should fall back to a general matmul routine.
- *
- * For example, for single-threaded single-precision GEMM you can say
- *
- * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
- * 0, 1,
- * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
- *
- * @param m is rows in `A` and `C`
- * @param n is cols in `B` and `C`
- * @param k is cols in `A` and rows in `B`
- * @param A is first input matrix (always transposed)
- * @param lda is row stride of `A`
- * @param B is second input matrix (never transposed)
- * @param ldb is row stride of `B`
- * @param C is input/output array of output matrices
- * @param ldc is row stride of `C`
- * @param ith is thread id (must be less than `nth`)
- * @param nth is number of threads (must be greater than zero)
- * @param Atype is GGML data type of `A`
- * @param Btype is GGML data type of `B`
- * @param Ctype is GGML data type of `C`
- * @return true if this function was able to service the matmul request
- */
-bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
- int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
-
- assert(m >= 0);
- assert(n >= 0);
- assert(k >= 0);
- assert(lda >= k);
- assert(ldb >= k);
- assert(ldc >= m);
- assert(nth > 0);
- assert(ith < nth);
-
- if (Ctype != GGML_TYPE_F32)
- return false;
-
- switch (Atype) {
-
- case GGML_TYPE_F32: {
- if (Btype != GGML_TYPE_F32)
- return false;
-#if defined(__AVX512F__)
- if (k % 16)
- return false;
- tinyBLAS<16, __m512, __m512, float, float, float> tb{
- k, (const float *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#elif defined(__AVX__) || defined(__AVX2__)
- if (k % 8)
- return false;
- tinyBLAS<8, __m256, __m256, float, float, float> tb{
- k, (const float *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#elif defined(__ARM_NEON)
- if (n < 4)
- return false;
- if (k % 4)
- return false;
- tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
- k, (const float *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#else
- return false;
-#endif
- }
-
- case GGML_TYPE_F16: {
-#if defined(__AVX512F__)
- if (k % 16)
- return false;
- if (Btype != GGML_TYPE_F32)
- return false;
- tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
- if (k % 8)
- return false;
- if (Btype != GGML_TYPE_F32)
- return false;
- tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
- if (n < 8)
- return false;
- if (k % 8)
- return false;
- if (Btype != GGML_TYPE_F16)
- return false;
- tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const ggml_fp16_t *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#elif defined(__ARM_NEON) && !defined(_MSC_VER)
- if (k % 4)
- return false;
- if (Btype != GGML_TYPE_F32)
- return false;
- tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#else
- return false;
-#endif
- }
-
- case GGML_TYPE_Q8_0: {
- if (Btype != GGML_TYPE_Q8_0)
- return false;
-#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
- tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
- k, (const block_q8_0 *)A, lda,
- (const block_q8_0 *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#elif defined(__ARM_FEATURE_DOTPROD)
- tinyBLAS_Q0_ARM<block_q8_0> tb{
- k, (const block_q8_0 *)A, lda,
- (const block_q8_0 *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#else
- return false;
-#endif
- }
-
- case GGML_TYPE_Q4_0: {
- if (Btype != GGML_TYPE_Q8_0)
- return false;
-#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
- tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
- k, (const block_q4_0 *)A, lda,
- (const block_q8_0 *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#elif defined(__ARM_FEATURE_DOTPROD)
- tinyBLAS_Q0_ARM<block_q4_0> tb{
- k, (const block_q4_0 *)A, lda,
- (const block_q8_0 *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#else
- return false;
-#endif
- }
-
- default:
- return false;
- }
-
- (void)m;
- (void)n;
- (void)k;
- (void)A;
- (void)lda;
- (void)B;
- (void)ldb;
- (void)C;
- (void)ldc;
- (void)ith;
- (void)nth;
- (void)Atype;
- (void)Btype;
- (void)Ctype;
-}
--- a/ggml/src/sgemm.h
+++ /dev/null
-#pragma once
-#include <stdint.h>
-#include <stdbool.h>
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
- const void *, int64_t, void *, int64_t, int, int,
- int, int, int);
-
-#ifdef __cplusplus
-}
-#endif