]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : move sgemm sources to llamafile subfolder (#8394)
authorGeorgi Gerganov <redacted>
Wed, 10 Jul 2024 12:23:29 +0000 (15:23 +0300)
committerGitHub <redacted>
Wed, 10 Jul 2024 12:23:29 +0000 (15:23 +0300)
ggml-ci

Makefile
ggml/CMakeLists.txt
ggml/src/CMakeLists.txt
ggml/src/ggml.c
ggml/src/llamafile/sgemm.cpp [new file with mode: 0644]
ggml/src/llamafile/sgemm.h [new file with mode: 0644]
ggml/src/sgemm.cpp [deleted file]
ggml/src/sgemm.h [deleted file]

index b70ebaed546386b0c456526ada394cbb2d9c35ad..68197fef80019910ae9b54aa1a54f0e6b5c74a59 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -554,7 +554,7 @@ endif # GGML_BLIS
 
 ifndef GGML_NO_LLAMAFILE
        MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
-       OBJ_GGML    += ggml/src/sgemm.o
+       OBJ_GGML    += ggml/src/llamafile/sgemm.o
 endif
 
 ifdef GGML_RPC
@@ -983,9 +983,9 @@ ggml/src/ggml-blas.o: \
        $(CXX) $(CXXFLAGS) -c $< -o $@
 
 ifndef GGML_NO_LLAMAFILE
-ggml/src/sgemm.o: \
-       ggml/src/sgemm.cpp \
-       ggml/src/sgemm.h \
+ggml/src/llamafile/sgemm.o: \
+       ggml/src/llamafile/sgemm.cpp \
+       ggml/src/llamafile/sgemm.h \
        ggml/include/ggml.h
        $(CXX) $(CXXFLAGS) -c $< -o $@
 endif # GGML_NO_LLAMAFILE
index 0d0d52d57597100250e2caceb106865ed62522a2..649ac3dcc4f63aae1c907231200810433e9f2c77 100644 (file)
@@ -104,7 +104,7 @@ option(GGML_ACCELERATE                      "ggml: enable Accelerate framework"
 option(GGML_BLAS                            "ggml: use BLAS"                                  ${GGML_BLAS_DEFAULT})
 set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
                                             "ggml: BLAS library vendor")
-option(GGML_LLAMAFILE                       "ggml: use ggml SGEMM"                            OFF)
+option(GGML_LLAMAFILE                       "ggml: use LLAMAFILE"                             OFF)
 
 option(GGML_CUDA                            "ggml: use CUDA"                                  OFF)
 option(GGML_CUDA_FORCE_DMMV                 "ggml: use dmmv instead of mmvq CUDA kernels"     OFF)
index aae5b8e9fe35c4c7f621c2722d8bbbe284f74dda..c5ee7e4255ee5ff91b9230d9c43268d1a9fb5a66 100644 (file)
@@ -238,12 +238,12 @@ if (GGML_BLAS)
 endif()
 
 if (GGML_LLAMAFILE)
-    message(STATUS "Using ggml SGEMM")
+    message(STATUS "Using llamafile")
 
     add_compile_definitions(GGML_USE_LLAMAFILE)
 
-    set(GGML_HEADERS_LLAMAFILE sgemm.h)
-    set(GGML_SOURCES_LLAMAFILE sgemm.cpp)
+    set(GGML_HEADERS_LLAMAFILE llamafile/sgemm.h)
+    set(GGML_SOURCES_LLAMAFILE llamafile/sgemm.cpp)
 endif()
 
 if (GGML_CUDA)
index c0aced3d2d06928546e3a5c5ccc54098bee555ef..1bb731e16d51dd1f5e249173bea238a90194d287 100644 (file)
@@ -6,7 +6,6 @@
 #include "ggml.h"
 #include "ggml-aarch64.h"
 
-
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -43,7 +42,7 @@
 #endif
 
 #ifdef GGML_USE_LLAMAFILE
-#include "sgemm.h"
+#include <llamafile/sgemm.h>
 #endif
 
 #if defined(_MSC_VER)
diff --git a/ggml/src/llamafile/sgemm.cpp b/ggml/src/llamafile/sgemm.cpp
new file mode 100644 (file)
index 0000000..6626ceb
--- /dev/null
@@ -0,0 +1,1027 @@
+// Copyright 2024 Mozilla Foundation
+//
+// Permission is hereby granted, free of charge, to any person obtaining
+// a copy of this software and associated documentation files (the
+// "Software"), to deal in the Software without restriction, including
+// without limitation the rights to use, copy, modify, merge, publish,
+// distribute, sublicense, and/or sell copies of the Software, and to
+// permit persons to whom the Software is furnished to do so, subject to
+// the following conditions:
+//
+// The above copyright notice and this permission notice shall be
+// included in all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+//
+//                   _   _          ___ _      _   ___
+//                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
+//                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
+//                   \__|_|_||_\_, |___/____/_/ \_\___/
+//                             |__/
+//
+//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
+//
+//
+// This file implements multithreaded CPU matrix multiplication for the
+// common contiguous use case C = Aᵀ * B. These kernels are designed to
+// have excellent performance[1] for matrices that fit in the CPU cache
+// without imposing any overhead such as cache filling or malloc calls.
+//
+// This implementation does not guarantee any upper bound with rounding
+// errors, which grow along with k. Our goal's to maximally exploit the
+// hardware for performance, and then use whatever resources remain for
+// improving numerical accuracy.
+//
+// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
+//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Wpedantic"
+#pragma GCC diagnostic ignored "-Wignored-attributes"
+#endif
+
+#include "sgemm.h"
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+
+#ifdef _MSC_VER
+#define NOINLINE __declspec(noinline)
+#else
+#define NOINLINE __attribute__((__noinline__))
+#endif
+
+#if defined(__ARM_NEON) || defined(__AVX512F__)
+#define VECTOR_REGISTERS 32
+#else
+#define VECTOR_REGISTERS 16
+#endif
+
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+namespace {
+
+inline float unhalf(ggml_fp16_t d) {
+    return GGML_FP16_TO_FP32(d);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED ARITHMETIC OPERATIONS
+
+#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
+inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
+inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
+#endif  // __SSE__
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
+inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
+inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
+#endif // __AVX__
+
+#if defined(__AVX512F__)
+inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
+inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
+inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
+#endif // __AVX512F__
+
+#if defined(__ARM_NEON)
+inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
+inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
+inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
+#endif // __ARM_NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
+inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
+inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED FUSED MULTIPLY ADD
+
+/**
+ * Computes a * b + c.
+ */
+template <typename T, typename U>
+inline U madd(T a, T b, U c) {
+    return add(mul(a, b), c);
+}
+
+#if defined(__FMA__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <>
+inline __m256 madd(__m256 a, __m256 b, __m256 c) {
+    return _mm256_fmadd_ps(a, b, c);
+}
+#endif
+#if defined(__AVX512F__)
+template <>
+inline __m512 madd(__m512 a, __m512 b, __m512 c) {
+    return _mm512_fmadd_ps(a, b, c);
+}
+#endif
+#endif
+
+#if defined(__ARM_FEATURE_FMA)
+template <>
+inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+    return vfmaq_f32(c, b, a);
+}
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+template <>
+inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
+    return vfmaq_f16(c, b, a);
+}
+#endif
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED HORIZONTAL SUM
+
+#if defined(__ARM_NEON)
+inline float hsum(float32x4_t x) {
+    return vaddvq_f32(x);
+}
+#endif // __ARM_NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+inline float hsum(float16x8_t x) {
+    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
+                                vcvt_f32_f16(vget_high_f16(x))));
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline float hsum(__m128 x) {
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
+    x = _mm_add_ss(x, _mm_movehdup_ps(x));
+#else
+    __m128 t;
+    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
+    x = _mm_add_ps(x, t);
+    t = _mm_movehl_ps(t, x);
+    x = _mm_add_ss(x, t);
+#endif
+    return _mm_cvtss_f32(x);
+}
+#endif
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline float hsum(__m256 x) {
+    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
+                           _mm256_castps256_ps128(x)));
+}
+#endif // __AVX__
+
+#if defined(__AVX512F__)
+inline float hsum(__m512 x) {
+    return _mm512_reduce_add_ps(x);
+}
+#endif // __AVX512F__
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED MEMORY LOADING
+
+template <typename T, typename U> T load(const U *);
+
+#if defined(__ARM_NEON)
+template <> inline float32x4_t load(const float *p) {
+    return vld1q_f32(p);
+}
+#if !defined(_MSC_VER)
+template <> inline float16x8_t load(const ggml_fp16_t *p) {
+    return vld1q_f16((const float16_t *)p);
+}
+template <> inline float32x4_t load(const ggml_fp16_t *p) {
+    return vcvt_f32_f16(vld1_f16((const float16_t *)p));
+}
+#endif // _MSC_VER
+#endif // __ARM_NEON
+
+#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <> inline __m128 load(const float *p) {
+    return _mm_loadu_ps(p);
+}
+#endif  // __SSE__
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <> inline __m256 load(const float *p) {
+    return _mm256_loadu_ps(p);
+}
+#endif // __AVX__
+
+#if defined(__F16C__)
+template <> inline __m256 load(const ggml_fp16_t *p) {
+    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
+}
+#endif // __F16C__
+
+#if defined(__AVX512F__)
+template <> inline __m512 load(const float *p) {
+    return _mm512_loadu_ps(p);
+}
+template <> inline __m512 load(const ggml_fp16_t *p) {
+    return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
+}
+#endif // __AVX512F__
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// FLOATING POINT MATRIX MULTIPLICATION
+
+template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
+class tinyBLAS {
+  public:
+    tinyBLAS(int64_t k,
+             const TA *A, int64_t lda,
+             const TB *B, int64_t ldb,
+             TC *C, int64_t ldc,
+             int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
+#if VECTOR_REGISTERS == 32
+        case 0x55:
+            mc = 5;
+            nc = 5;
+            gemm<5, 5>(m0, m, n0, n);
+            break;
+        case 0x45:
+            mc = 4;
+            nc = 5;
+            gemm<4, 5>(m0, m, n0, n);
+            break;
+        case 0x54:
+            mc = 5;
+            nc = 4;
+            gemm<5, 4>(m0, m, n0, n);
+            break;
+        case 0x44:
+            mc = 4;
+            nc = 4;
+            gemm<4, 4>(m0, m, n0, n);
+            break;
+        case 0x53:
+            mc = 5;
+            nc = 3;
+            gemm<5, 3>(m0, m, n0, n);
+            break;
+        case 0x35:
+            mc = 3;
+            nc = 5;
+            gemm<3, 5>(m0, m, n0, n);
+            break;
+        case 0x43:
+            mc = 4;
+            nc = 3;
+            gemm<4, 3>(m0, m, n0, n);
+            break;
+#else
+        case 0x55:
+        case 0x54:
+        case 0x53:
+        case 0x45:
+        case 0x44:
+        case 0x43:
+            mc = 4;
+            nc = 3;
+            gemm<4, 3>(m0, m, n0, n);
+            break;
+        case 0x35:
+#endif
+        case 0x34:
+            mc = 3;
+            nc = 4;
+            gemm<3, 4>(m0, m, n0, n);
+            break;
+        case 0x52:
+            mc = 5;
+            nc = 2;
+            gemm<5, 2>(m0, m, n0, n);
+            break;
+        case 0x33:
+            mc = 3;
+            nc = 3;
+            gemm<3, 3>(m0, m, n0, n);
+            break;
+        case 0x25:
+            mc = 2;
+            nc = 5;
+            gemm<2, 5>(m0, m, n0, n);
+            break;
+        case 0x42:
+            mc = 4;
+            nc = 2;
+            gemm<4, 2>(m0, m, n0, n);
+            break;
+        case 0x24:
+            mc = 2;
+            nc = 4;
+            gemm<2, 4>(m0, m, n0, n);
+            break;
+        case 0x32:
+            mc = 3;
+            nc = 2;
+            gemm<3, 2>(m0, m, n0, n);
+            break;
+        case 0x23:
+            mc = 2;
+            nc = 3;
+            gemm<2, 3>(m0, m, n0, n);
+            break;
+        case 0x51:
+            mc = 5;
+            nc = 1;
+            gemm<5, 1>(m0, m, n0, n);
+            break;
+        case 0x41:
+            mc = 4;
+            nc = 1;
+            gemm<4, 1>(m0, m, n0, n);
+            break;
+        case 0x22:
+            mc = 2;
+            nc = 2;
+            gemm<2, 2>(m0, m, n0, n);
+            break;
+        case 0x15:
+            mc = 1;
+            nc = 5;
+            gemm<1, 5>(m0, m, n0, n);
+            break;
+        case 0x14:
+            mc = 1;
+            nc = 4;
+            gemm<1, 4>(m0, m, n0, n);
+            break;
+        case 0x31:
+            mc = 3;
+            nc = 1;
+            gemm<3, 1>(m0, m, n0, n);
+            break;
+        case 0x13:
+            mc = 1;
+            nc = 3;
+            gemm<1, 3>(m0, m, n0, n);
+            break;
+        case 0x21:
+            mc = 2;
+            nc = 1;
+            gemm<2, 1>(m0, m, n0, n);
+            break;
+        case 0x12:
+            mc = 1;
+            nc = 2;
+            gemm<1, 2>(m0, m, n0, n);
+            break;
+        case 0x11:
+            mc = 1;
+            nc = 1;
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            D Cv[RN][RM] = {};
+            for (int64_t l = 0; l < k; l += KN)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i)
+                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
+                                        load<V>(B + ldb * (jj + j) + l),
+                                        Cv[j][i]);
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *const C;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// QUANT ZERO MATRIX MULTIPLICATION
+
+#if defined(__ARM_FEATURE_DOTPROD)
+template <typename TA>
+class tinyBLAS_Q0_ARM {
+  public:
+    tinyBLAS_Q0_ARM(int64_t k,
+                    const TA *A, int64_t lda,
+                    const block_q8_0 *B, int64_t ldb,
+                    float *C, int64_t ldc,
+                    int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
+        case 0x33:
+            mc = 3;
+            nc = 3;
+            gemm<3, 3>(m0, m, n0, n);
+            break;
+        case 0x32:
+            mc = 3;
+            nc = 2;
+            gemm<3, 2>(m0, m, n0, n);
+            break;
+        case 0x23:
+            mc = 2;
+            nc = 3;
+            gemm<2, 3>(m0, m, n0, n);
+            break;
+        case 0x22:
+            mc = 2;
+            nc = 2;
+            gemm<2, 2>(m0, m, n0, n);
+            break;
+        case 0x31:
+            mc = 3;
+            nc = 1;
+            gemm<3, 1>(m0, m, n0, n);
+            break;
+        case 0x13:
+            mc = 1;
+            nc = 3;
+            gemm<1, 3>(m0, m, n0, n);
+            break;
+        case 0x21:
+            mc = 2;
+            nc = 1;
+            gemm<2, 1>(m0, m, n0, n);
+            break;
+        case 0x12:
+            mc = 1;
+            nc = 2;
+            gemm<1, 2>(m0, m, n0, n);
+            break;
+        case 0x11:
+            mc = 1;
+            nc = 1;
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            float32x4_t Cv[RN][RM] = {};
+            for (int64_t l = 0; l < k; ++l)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i)
+                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
+                                               vcvtq_f32_s32(vdotq_s32(
+                                                   vdotq_s32(vdupq_n_s32(0),
+                                                             load_lo(A + lda * (ii + i) + l),
+                                                             load_lo(B + ldb * (jj + j) + l)),
+                                                   load_hi(A + lda * (ii + i) + l),
+                                                   load_hi(B + ldb * (jj + j) + l))),
+                                               unhalf(A[lda * (ii + i) + l].d) *
+                                               unhalf(B[ldb * (jj + j) + l].d));
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+
+    inline int8x16_t load_lo(const block_q8_0 *b) {
+        return vld1q_s8(b->qs);
+    }
+
+    inline int8x16_t load_hi(const block_q8_0 *b) {
+        return vld1q_s8(b->qs + 16);
+    }
+
+    inline int8x16_t load_lo(const block_q4_0 *b) {
+        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
+                                                     vdupq_n_u8(0x0f))),
+                        vdupq_n_s8(0x8));
+    }
+
+    inline int8x16_t load_hi(const block_q4_0 *b) {
+        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
+                        vdupq_n_s8(0x8));
+    }
+
+    const TA *const A;
+    const block_q8_0 *const B;
+    float *const C;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+#endif // __ARM_FEATURE_DOTPROD
+
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_Q0_AVX {
+  public:
+    tinyBLAS_Q0_AVX(int64_t k,
+                    const TA *A, int64_t lda,
+                    const TB *B, int64_t ldb,
+                    TC *C, int64_t ldc,
+                    int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
+#if VECTOR_REGISTERS == 32
+        case 0x44:
+            mc = 4;
+            nc = 4;
+            gemm<4, 4>(m0, m, n0, n);
+            break;
+        case 0x43:
+            mc = 4;
+            nc = 3;
+            gemm<4, 3>(m0, m, n0, n);
+            break;
+        case 0x34:
+            mc = 3;
+            nc = 4;
+            gemm<3, 4>(m0, m, n0, n);
+            break;
+        case 0x33:
+            mc = 3;
+            nc = 3;
+            gemm<3, 3>(m0, m, n0, n);
+            break;
+        case 0x42:
+            mc = 4;
+            nc = 2;
+            gemm<4, 2>(m0, m, n0, n);
+            break;
+        case 0x24:
+            mc = 2;
+            nc = 4;
+            gemm<2, 4>(m0, m, n0, n);
+            break;
+#else
+        case 0x44:
+        case 0x43:
+        case 0x42:
+            mc = 4;
+            nc = 2;
+            gemm<4, 2>(m0, m, n0, n);
+            break;
+        case 0x34:
+        case 0x24:
+            mc = 2;
+            nc = 4;
+            gemm<2, 4>(m0, m, n0, n);
+            break;
+        case 0x33:
+#endif
+        case 0x32:
+            mc = 3;
+            nc = 2;
+            gemm<3, 2>(m0, m, n0, n);
+            break;
+        case 0x23:
+            mc = 2;
+            nc = 3;
+            gemm<2, 3>(m0, m, n0, n);
+            break;
+        case 0x41:
+            mc = 4;
+            nc = 1;
+            gemm<4, 1>(m0, m, n0, n);
+            break;
+        case 0x22:
+            mc = 2;
+            nc = 2;
+            gemm<2, 2>(m0, m, n0, n);
+            break;
+        case 0x14:
+            mc = 1;
+            nc = 4;
+            gemm<1, 4>(m0, m, n0, n);
+            break;
+        case 0x31:
+            mc = 3;
+            nc = 1;
+            gemm<3, 1>(m0, m, n0, n);
+            break;
+        case 0x13:
+            mc = 1;
+            nc = 3;
+            gemm<1, 3>(m0, m, n0, n);
+            break;
+        case 0x21:
+            mc = 2;
+            nc = 1;
+            gemm<2, 1>(m0, m, n0, n);
+            break;
+        case 0x12:
+            mc = 1;
+            nc = 2;
+            gemm<1, 2>(m0, m, n0, n);
+            break;
+        case 0x11:
+            mc = 1;
+            nc = 1;
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][RM] = {};
+            for (int64_t l = 0; l < k; ++l)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i) {
+#if defined(__AVX2__)
+                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                              load(A + lda * (ii + i) + l)),
+                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+                                                              load(A + lda * (ii + i) + l)));
+#else
+                        __m128i ali0 = load0(A + lda * (ii + i) + l);
+                        __m128i ali1 = load1(A + lda * (ii + i) + l);
+                        __m128i blj0 = load0(B + ldb * (jj + j) + l);
+                        __m128i blj1 = load1(B + ldb * (jj + j) + l);
+
+                        __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
+                        __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
+                        __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
+                        __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
+
+                        // updot
+                        const __m128i oneFill = _mm_set1_epi16(1);
+                        __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
+                        __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
+                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
+#endif
+                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
+                                                       unhalf(B[ldb * (jj + j) + l].d)),
+                                                       udTmp,
+                                                       Cv[j][i]);
+                    }
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+
+    inline __m256i load(const block_q8_0 *b) {
+        return _mm256_loadu_si256((const __m256i *)b->qs);
+    }
+
+    inline __m128i load0(const block_q8_0 *b) {
+        return _mm_loadu_si128((const __m128i *)b->qs);
+    }
+
+    inline __m128i load1(const block_q8_0 *b) {
+        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
+    }
+
+    inline __m256i load(const block_q4_0 *b) {
+        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
+    }
+
+    inline __m128i load0(const block_q4_0 *b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
+    }
+
+    inline __m128i load1(const block_q4_0 *b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
+    }
+
+    inline __m256 updot(__m256i u, __m256i s) {
+        __m256i res;
+#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
+#else
+        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
+#endif
+        return _mm256_cvtepi32_ps(res);
+    }
+
+    static inline __m256i denibble(const uint8_t *p) {
+        __m128i x = _mm_loadu_si128((const __m128i *)p);
+        return _mm256_and_si256(_mm256_set1_epi8(15),
+                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
+                                                        _mm_srli_epi16(x, 4), 1));
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *const C;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+#endif // __AVX__
+
+} // namespace
+
+/**
+ * Performs optimized matrix multiplication on CPU.
+ *
+ * This subroutine may compute C = Aᵀ * B with column major ordering.
+ * Despite its name, this isn't a generalized implementation. Work is
+ * only performed when a handwritten kernel is written and available.
+ * Otherwise the caller should fall back to a general matmul routine.
+ *
+ * For example, for single-threaded single-precision GEMM you can say
+ *
+ *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
+ *                     0, 1,
+ *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
+ *
+ * @param m is rows in `A` and `C`
+ * @param n is cols in `B` and `C`
+ * @param k is cols in `A` and rows in `B`
+ * @param A is first input matrix (always transposed)
+ * @param lda is row stride of `A`
+ * @param B is second input matrix (never transposed)
+ * @param ldb is row stride of `B`
+ * @param C is input/output array of output matrices
+ * @param ldc is row stride of `C`
+ * @param ith is thread id (must be less than `nth`)
+ * @param nth is number of threads (must be greater than zero)
+ * @param Atype is GGML data type of `A`
+ * @param Btype is GGML data type of `B`
+ * @param Ctype is GGML data type of `C`
+ * @return true if this function was able to service the matmul request
+ */
+bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
+
+    assert(m >= 0);
+    assert(n >= 0);
+    assert(k >= 0);
+    assert(lda >= k);
+    assert(ldb >= k);
+    assert(ldc >= m);
+    assert(nth > 0);
+    assert(ith < nth);
+
+    if (Ctype != GGML_TYPE_F32)
+        return false;
+
+    switch (Atype) {
+
+    case GGML_TYPE_F32: {
+        if (Btype != GGML_TYPE_F32)
+            return false;
+#if defined(__AVX512F__)
+        if (k % 16)
+            return false;
+        tinyBLAS<16, __m512, __m512, float, float, float> tb{
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__AVX__) || defined(__AVX2__)
+        if (k % 8)
+            return false;
+        tinyBLAS<8, __m256, __m256, float, float, float> tb{
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__ARM_NEON)
+        if (n < 4)
+            return false;
+        if (k % 4)
+            return false;
+        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
+    case GGML_TYPE_F16: {
+#if defined(__AVX512F__)
+        if (k % 16)
+            return false;
+        if (Btype != GGML_TYPE_F32)
+            return false;
+        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
+            k, (const ggml_fp16_t *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
+        if (k % 8)
+            return false;
+        if (Btype != GGML_TYPE_F32)
+            return false;
+        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
+            k, (const ggml_fp16_t *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+        if (n < 8)
+            return false;
+        if (k % 8)
+            return false;
+        if (Btype != GGML_TYPE_F16)
+            return false;
+        tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
+            k, (const ggml_fp16_t *)A, lda,
+            (const ggml_fp16_t *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__ARM_NEON) && !defined(_MSC_VER)
+        if (k % 4)
+            return false;
+        if (Btype != GGML_TYPE_F32)
+            return false;
+        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
+            k, (const ggml_fp16_t *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
+    case GGML_TYPE_Q8_0: {
+        if (Btype != GGML_TYPE_Q8_0)
+           return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
+            k, (const block_q8_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__ARM_FEATURE_DOTPROD)
+        tinyBLAS_Q0_ARM<block_q8_0> tb{
+            k, (const block_q8_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
+    case GGML_TYPE_Q4_0: {
+        if (Btype != GGML_TYPE_Q8_0)
+            return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
+            k, (const block_q4_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__ARM_FEATURE_DOTPROD)
+        tinyBLAS_Q0_ARM<block_q4_0> tb{
+            k, (const block_q4_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
+    default:
+        return false;
+    }
+
+    (void)m;
+    (void)n;
+    (void)k;
+    (void)A;
+    (void)lda;
+    (void)B;
+    (void)ldb;
+    (void)C;
+    (void)ldc;
+    (void)ith;
+    (void)nth;
+    (void)Atype;
+    (void)Btype;
+    (void)Ctype;
+}
diff --git a/ggml/src/llamafile/sgemm.h b/ggml/src/llamafile/sgemm.h
new file mode 100644 (file)
index 0000000..caf6dd5
--- /dev/null
@@ -0,0 +1,14 @@
+#pragma once
+#include <stdint.h>
+#include <stdbool.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
+                     const void *, int64_t, void *, int64_t, int, int,
+                     int, int, int);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/src/sgemm.cpp b/ggml/src/sgemm.cpp
deleted file mode 100644 (file)
index 6626ceb..0000000
+++ /dev/null
@@ -1,1027 +0,0 @@
-// Copyright 2024 Mozilla Foundation
-//
-// Permission is hereby granted, free of charge, to any person obtaining
-// a copy of this software and associated documentation files (the
-// "Software"), to deal in the Software without restriction, including
-// without limitation the rights to use, copy, modify, merge, publish,
-// distribute, sublicense, and/or sell copies of the Software, and to
-// permit persons to whom the Software is furnished to do so, subject to
-// the following conditions:
-//
-// The above copyright notice and this permission notice shall be
-// included in all copies or substantial portions of the Software.
-//
-// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
-// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
-// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
-// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-// SOFTWARE.
-
-//
-//                   _   _          ___ _      _   ___
-//                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
-//                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
-//                   \__|_|_||_\_, |___/____/_/ \_\___/
-//                             |__/
-//
-//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
-//
-//
-// This file implements multithreaded CPU matrix multiplication for the
-// common contiguous use case C = Aᵀ * B. These kernels are designed to
-// have excellent performance[1] for matrices that fit in the CPU cache
-// without imposing any overhead such as cache filling or malloc calls.
-//
-// This implementation does not guarantee any upper bound with rounding
-// errors, which grow along with k. Our goal's to maximally exploit the
-// hardware for performance, and then use whatever resources remain for
-// improving numerical accuracy.
-//
-// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
-//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic ignored "-Wpedantic"
-#pragma GCC diagnostic ignored "-Wignored-attributes"
-#endif
-
-#include "sgemm.h"
-#include "ggml-impl.h"
-#include "ggml-quants.h"
-
-#ifdef _MSC_VER
-#define NOINLINE __declspec(noinline)
-#else
-#define NOINLINE __attribute__((__noinline__))
-#endif
-
-#if defined(__ARM_NEON) || defined(__AVX512F__)
-#define VECTOR_REGISTERS 32
-#else
-#define VECTOR_REGISTERS 16
-#endif
-
-#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
-namespace {
-
-inline float unhalf(ggml_fp16_t d) {
-    return GGML_FP16_TO_FP32(d);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED ARITHMETIC OPERATIONS
-
-#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
-inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
-inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
-#endif  // __SSE__
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
-inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
-inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
-#endif // __AVX__
-
-#if defined(__AVX512F__)
-inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
-inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
-inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
-#endif // __AVX512F__
-
-#if defined(__ARM_NEON)
-inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
-inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
-inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
-#endif // __ARM_NEON
-
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
-inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
-inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
-inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED FUSED MULTIPLY ADD
-
-/**
- * Computes a * b + c.
- */
-template <typename T, typename U>
-inline U madd(T a, T b, U c) {
-    return add(mul(a, b), c);
-}
-
-#if defined(__FMA__)
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-template <>
-inline __m256 madd(__m256 a, __m256 b, __m256 c) {
-    return _mm256_fmadd_ps(a, b, c);
-}
-#endif
-#if defined(__AVX512F__)
-template <>
-inline __m512 madd(__m512 a, __m512 b, __m512 c) {
-    return _mm512_fmadd_ps(a, b, c);
-}
-#endif
-#endif
-
-#if defined(__ARM_FEATURE_FMA)
-template <>
-inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
-    return vfmaq_f32(c, b, a);
-}
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
-template <>
-inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
-    return vfmaq_f16(c, b, a);
-}
-#endif
-#endif
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED HORIZONTAL SUM
-
-#if defined(__ARM_NEON)
-inline float hsum(float32x4_t x) {
-    return vaddvq_f32(x);
-}
-#endif // __ARM_NEON
-
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
-inline float hsum(float16x8_t x) {
-    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
-                                vcvt_f32_f16(vget_high_f16(x))));
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline float hsum(__m128 x) {
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
-    x = _mm_add_ss(x, _mm_movehdup_ps(x));
-#else
-    __m128 t;
-    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
-    x = _mm_add_ps(x, t);
-    t = _mm_movehl_ps(t, x);
-    x = _mm_add_ss(x, t);
-#endif
-    return _mm_cvtss_f32(x);
-}
-#endif
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline float hsum(__m256 x) {
-    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
-                           _mm256_castps256_ps128(x)));
-}
-#endif // __AVX__
-
-#if defined(__AVX512F__)
-inline float hsum(__m512 x) {
-    return _mm512_reduce_add_ps(x);
-}
-#endif // __AVX512F__
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED MEMORY LOADING
-
-template <typename T, typename U> T load(const U *);
-
-#if defined(__ARM_NEON)
-template <> inline float32x4_t load(const float *p) {
-    return vld1q_f32(p);
-}
-#if !defined(_MSC_VER)
-template <> inline float16x8_t load(const ggml_fp16_t *p) {
-    return vld1q_f16((const float16_t *)p);
-}
-template <> inline float32x4_t load(const ggml_fp16_t *p) {
-    return vcvt_f32_f16(vld1_f16((const float16_t *)p));
-}
-#endif // _MSC_VER
-#endif // __ARM_NEON
-
-#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-template <> inline __m128 load(const float *p) {
-    return _mm_loadu_ps(p);
-}
-#endif  // __SSE__
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-template <> inline __m256 load(const float *p) {
-    return _mm256_loadu_ps(p);
-}
-#endif // __AVX__
-
-#if defined(__F16C__)
-template <> inline __m256 load(const ggml_fp16_t *p) {
-    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
-}
-#endif // __F16C__
-
-#if defined(__AVX512F__)
-template <> inline __m512 load(const float *p) {
-    return _mm512_loadu_ps(p);
-}
-template <> inline __m512 load(const ggml_fp16_t *p) {
-    return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
-}
-#endif // __AVX512F__
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// FLOATING POINT MATRIX MULTIPLICATION
-
-template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
-class tinyBLAS {
-  public:
-    tinyBLAS(int64_t k,
-             const TA *A, int64_t lda,
-             const TB *B, int64_t ldb,
-             TC *C, int64_t ldc,
-             int ith, int nth)
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-    }
-
-    void matmul(int64_t m, int64_t n) {
-        mnpack(0, m, 0, n);
-    }
-
-  private:
-    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t mc, nc, mp, np;
-        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
-#if VECTOR_REGISTERS == 32
-        case 0x55:
-            mc = 5;
-            nc = 5;
-            gemm<5, 5>(m0, m, n0, n);
-            break;
-        case 0x45:
-            mc = 4;
-            nc = 5;
-            gemm<4, 5>(m0, m, n0, n);
-            break;
-        case 0x54:
-            mc = 5;
-            nc = 4;
-            gemm<5, 4>(m0, m, n0, n);
-            break;
-        case 0x44:
-            mc = 4;
-            nc = 4;
-            gemm<4, 4>(m0, m, n0, n);
-            break;
-        case 0x53:
-            mc = 5;
-            nc = 3;
-            gemm<5, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
-            mc = 3;
-            nc = 5;
-            gemm<3, 5>(m0, m, n0, n);
-            break;
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-#else
-        case 0x55:
-        case 0x54:
-        case 0x53:
-        case 0x45:
-        case 0x44:
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
-#endif
-        case 0x34:
-            mc = 3;
-            nc = 4;
-            gemm<3, 4>(m0, m, n0, n);
-            break;
-        case 0x52:
-            mc = 5;
-            nc = 2;
-            gemm<5, 2>(m0, m, n0, n);
-            break;
-        case 0x33:
-            mc = 3;
-            nc = 3;
-            gemm<3, 3>(m0, m, n0, n);
-            break;
-        case 0x25:
-            mc = 2;
-            nc = 5;
-            gemm<2, 5>(m0, m, n0, n);
-            break;
-        case 0x42:
-            mc = 4;
-            nc = 2;
-            gemm<4, 2>(m0, m, n0, n);
-            break;
-        case 0x24:
-            mc = 2;
-            nc = 4;
-            gemm<2, 4>(m0, m, n0, n);
-            break;
-        case 0x32:
-            mc = 3;
-            nc = 2;
-            gemm<3, 2>(m0, m, n0, n);
-            break;
-        case 0x23:
-            mc = 2;
-            nc = 3;
-            gemm<2, 3>(m0, m, n0, n);
-            break;
-        case 0x51:
-            mc = 5;
-            nc = 1;
-            gemm<5, 1>(m0, m, n0, n);
-            break;
-        case 0x41:
-            mc = 4;
-            nc = 1;
-            gemm<4, 1>(m0, m, n0, n);
-            break;
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x15:
-            mc = 1;
-            nc = 5;
-            gemm<1, 5>(m0, m, n0, n);
-            break;
-        case 0x14:
-            mc = 1;
-            nc = 4;
-            gemm<1, 4>(m0, m, n0, n);
-            break;
-        case 0x31:
-            mc = 3;
-            nc = 1;
-            gemm<3, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-            mc = 1;
-            nc = 3;
-            gemm<1, 3>(m0, m, n0, n);
-            break;
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
-        }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
-    }
-
-    template <int RM, int RN>
-    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t ytiles = (m - m0) / RM;
-        int64_t xtiles = (n - n0) / RN;
-        int64_t tiles = xtiles * ytiles;
-        int64_t duty = (tiles + nth - 1) / nth;
-        int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles)
-            end = tiles;
-        for (int64_t job = start; job < end; ++job) {
-            int64_t ii = m0 + job / xtiles * RM;
-            int64_t jj = n0 + job % xtiles * RN;
-            D Cv[RN][RM] = {};
-            for (int64_t l = 0; l < k; l += KN)
-                for (int64_t j = 0; j < RN; ++j)
-                    for (int64_t i = 0; i < RM; ++i)
-                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
-                                        load<V>(B + ldb * (jj + j) + l),
-                                        Cv[j][i]);
-            for (int64_t j = 0; j < RN; ++j)
-                for (int64_t i = 0; i < RM; ++i)
-                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
-        }
-    }
-
-    const TA *const A;
-    const TB *const B;
-    TC *const C;
-    const int64_t k;
-    const int64_t lda;
-    const int64_t ldb;
-    const int64_t ldc;
-    const int ith;
-    const int nth;
-};
-
-//////////////////////////////////////////////////////////////////////////////////////////
-// QUANT ZERO MATRIX MULTIPLICATION
-
-#if defined(__ARM_FEATURE_DOTPROD)
-template <typename TA>
-class tinyBLAS_Q0_ARM {
-  public:
-    tinyBLAS_Q0_ARM(int64_t k,
-                    const TA *A, int64_t lda,
-                    const block_q8_0 *B, int64_t ldb,
-                    float *C, int64_t ldc,
-                    int ith, int nth)
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-    }
-
-    void matmul(int64_t m, int64_t n) {
-        mnpack(0, m, 0, n);
-    }
-
-  private:
-    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t mc, nc, mp, np;
-        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
-        case 0x33:
-            mc = 3;
-            nc = 3;
-            gemm<3, 3>(m0, m, n0, n);
-            break;
-        case 0x32:
-            mc = 3;
-            nc = 2;
-            gemm<3, 2>(m0, m, n0, n);
-            break;
-        case 0x23:
-            mc = 2;
-            nc = 3;
-            gemm<2, 3>(m0, m, n0, n);
-            break;
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x31:
-            mc = 3;
-            nc = 1;
-            gemm<3, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-            mc = 1;
-            nc = 3;
-            gemm<1, 3>(m0, m, n0, n);
-            break;
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
-        }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
-    }
-
-    template <int RM, int RN>
-    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t ytiles = (m - m0) / RM;
-        int64_t xtiles = (n - n0) / RN;
-        int64_t tiles = xtiles * ytiles;
-        int64_t duty = (tiles + nth - 1) / nth;
-        int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles)
-            end = tiles;
-        for (int64_t job = start; job < end; ++job) {
-            int64_t ii = m0 + job / xtiles * RM;
-            int64_t jj = n0 + job % xtiles * RN;
-            float32x4_t Cv[RN][RM] = {};
-            for (int64_t l = 0; l < k; ++l)
-                for (int64_t j = 0; j < RN; ++j)
-                    for (int64_t i = 0; i < RM; ++i)
-                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
-                                               vcvtq_f32_s32(vdotq_s32(
-                                                   vdotq_s32(vdupq_n_s32(0),
-                                                             load_lo(A + lda * (ii + i) + l),
-                                                             load_lo(B + ldb * (jj + j) + l)),
-                                                   load_hi(A + lda * (ii + i) + l),
-                                                   load_hi(B + ldb * (jj + j) + l))),
-                                               unhalf(A[lda * (ii + i) + l].d) *
-                                               unhalf(B[ldb * (jj + j) + l].d));
-            for (int64_t j = 0; j < RN; ++j)
-                for (int64_t i = 0; i < RM; ++i)
-                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
-        }
-    }
-
-    inline int8x16_t load_lo(const block_q8_0 *b) {
-        return vld1q_s8(b->qs);
-    }
-
-    inline int8x16_t load_hi(const block_q8_0 *b) {
-        return vld1q_s8(b->qs + 16);
-    }
-
-    inline int8x16_t load_lo(const block_q4_0 *b) {
-        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
-                                                     vdupq_n_u8(0x0f))),
-                        vdupq_n_s8(0x8));
-    }
-
-    inline int8x16_t load_hi(const block_q4_0 *b) {
-        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
-                        vdupq_n_s8(0x8));
-    }
-
-    const TA *const A;
-    const block_q8_0 *const B;
-    float *const C;
-    const int64_t k;
-    const int64_t lda;
-    const int64_t ldb;
-    const int64_t ldc;
-    const int ith;
-    const int nth;
-};
-#endif // __ARM_FEATURE_DOTPROD
-
-#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
-template <typename TA, typename TB, typename TC>
-class tinyBLAS_Q0_AVX {
-  public:
-    tinyBLAS_Q0_AVX(int64_t k,
-                    const TA *A, int64_t lda,
-                    const TB *B, int64_t ldb,
-                    TC *C, int64_t ldc,
-                    int ith, int nth)
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-    }
-
-    void matmul(int64_t m, int64_t n) {
-        mnpack(0, m, 0, n);
-    }
-
-  private:
-    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t mc, nc, mp, np;
-        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
-#if VECTOR_REGISTERS == 32
-        case 0x44:
-            mc = 4;
-            nc = 4;
-            gemm<4, 4>(m0, m, n0, n);
-            break;
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-        case 0x34:
-            mc = 3;
-            nc = 4;
-            gemm<3, 4>(m0, m, n0, n);
-            break;
-        case 0x33:
-            mc = 3;
-            nc = 3;
-            gemm<3, 3>(m0, m, n0, n);
-            break;
-        case 0x42:
-            mc = 4;
-            nc = 2;
-            gemm<4, 2>(m0, m, n0, n);
-            break;
-        case 0x24:
-            mc = 2;
-            nc = 4;
-            gemm<2, 4>(m0, m, n0, n);
-            break;
-#else
-        case 0x44:
-        case 0x43:
-        case 0x42:
-            mc = 4;
-            nc = 2;
-            gemm<4, 2>(m0, m, n0, n);
-            break;
-        case 0x34:
-        case 0x24:
-            mc = 2;
-            nc = 4;
-            gemm<2, 4>(m0, m, n0, n);
-            break;
-        case 0x33:
-#endif
-        case 0x32:
-            mc = 3;
-            nc = 2;
-            gemm<3, 2>(m0, m, n0, n);
-            break;
-        case 0x23:
-            mc = 2;
-            nc = 3;
-            gemm<2, 3>(m0, m, n0, n);
-            break;
-        case 0x41:
-            mc = 4;
-            nc = 1;
-            gemm<4, 1>(m0, m, n0, n);
-            break;
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x14:
-            mc = 1;
-            nc = 4;
-            gemm<1, 4>(m0, m, n0, n);
-            break;
-        case 0x31:
-            mc = 3;
-            nc = 1;
-            gemm<3, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-            mc = 1;
-            nc = 3;
-            gemm<1, 3>(m0, m, n0, n);
-            break;
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
-        }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
-    }
-
-    template <int RM, int RN>
-    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t ytiles = (m - m0) / RM;
-        int64_t xtiles = (n - n0) / RN;
-        int64_t tiles = xtiles * ytiles;
-        int64_t duty = (tiles + nth - 1) / nth;
-        int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles)
-            end = tiles;
-        for (int64_t job = start; job < end; ++job) {
-            int64_t ii = m0 + job / xtiles * RM;
-            int64_t jj = n0 + job % xtiles * RN;
-            __m256 Cv[RN][RM] = {};
-            for (int64_t l = 0; l < k; ++l)
-                for (int64_t j = 0; j < RN; ++j)
-                    for (int64_t i = 0; i < RM; ++i) {
-#if defined(__AVX2__)
-                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
-                                                              load(A + lda * (ii + i) + l)),
-                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
-                                                              load(A + lda * (ii + i) + l)));
-#else
-                        __m128i ali0 = load0(A + lda * (ii + i) + l);
-                        __m128i ali1 = load1(A + lda * (ii + i) + l);
-                        __m128i blj0 = load0(B + ldb * (jj + j) + l);
-                        __m128i blj1 = load1(B + ldb * (jj + j) + l);
-
-                        __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
-                        __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
-                        __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
-                        __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
-
-                        // updot
-                        const __m128i oneFill = _mm_set1_epi16(1);
-                        __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
-                        __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
-                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
-#endif
-                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
-                                                       unhalf(B[ldb * (jj + j) + l].d)),
-                                                       udTmp,
-                                                       Cv[j][i]);
-                    }
-            for (int64_t j = 0; j < RN; ++j)
-                for (int64_t i = 0; i < RM; ++i)
-                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
-        }
-    }
-
-    inline __m256i load(const block_q8_0 *b) {
-        return _mm256_loadu_si256((const __m256i *)b->qs);
-    }
-
-    inline __m128i load0(const block_q8_0 *b) {
-        return _mm_loadu_si128((const __m128i *)b->qs);
-    }
-
-    inline __m128i load1(const block_q8_0 *b) {
-        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
-    }
-
-    inline __m256i load(const block_q4_0 *b) {
-        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
-    }
-
-    inline __m128i load0(const block_q4_0 *b) {
-        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
-        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
-    }
-
-    inline __m128i load1(const block_q4_0 *b) {
-        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
-        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
-    }
-
-    inline __m256 updot(__m256i u, __m256i s) {
-        __m256i res;
-#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
-        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
-#else
-        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
-#endif
-        return _mm256_cvtepi32_ps(res);
-    }
-
-    static inline __m256i denibble(const uint8_t *p) {
-        __m128i x = _mm_loadu_si128((const __m128i *)p);
-        return _mm256_and_si256(_mm256_set1_epi8(15),
-                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
-                                                        _mm_srli_epi16(x, 4), 1));
-    }
-
-    const TA *const A;
-    const TB *const B;
-    TC *const C;
-    const int64_t k;
-    const int64_t lda;
-    const int64_t ldb;
-    const int64_t ldc;
-    const int ith;
-    const int nth;
-};
-#endif // __AVX__
-
-} // namespace
-
-/**
- * Performs optimized matrix multiplication on CPU.
- *
- * This subroutine may compute C = Aᵀ * B with column major ordering.
- * Despite its name, this isn't a generalized implementation. Work is
- * only performed when a handwritten kernel is written and available.
- * Otherwise the caller should fall back to a general matmul routine.
- *
- * For example, for single-threaded single-precision GEMM you can say
- *
- *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
- *                     0, 1,
- *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
- *
- * @param m is rows in `A` and `C`
- * @param n is cols in `B` and `C`
- * @param k is cols in `A` and rows in `B`
- * @param A is first input matrix (always transposed)
- * @param lda is row stride of `A`
- * @param B is second input matrix (never transposed)
- * @param ldb is row stride of `B`
- * @param C is input/output array of output matrices
- * @param ldc is row stride of `C`
- * @param ith is thread id (must be less than `nth`)
- * @param nth is number of threads (must be greater than zero)
- * @param Atype is GGML data type of `A`
- * @param Btype is GGML data type of `B`
- * @param Ctype is GGML data type of `C`
- * @return true if this function was able to service the matmul request
- */
-bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
-                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
-
-    assert(m >= 0);
-    assert(n >= 0);
-    assert(k >= 0);
-    assert(lda >= k);
-    assert(ldb >= k);
-    assert(ldc >= m);
-    assert(nth > 0);
-    assert(ith < nth);
-
-    if (Ctype != GGML_TYPE_F32)
-        return false;
-
-    switch (Atype) {
-
-    case GGML_TYPE_F32: {
-        if (Btype != GGML_TYPE_F32)
-            return false;
-#if defined(__AVX512F__)
-        if (k % 16)
-            return false;
-        tinyBLAS<16, __m512, __m512, float, float, float> tb{
-            k, (const float *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__AVX__) || defined(__AVX2__)
-        if (k % 8)
-            return false;
-        tinyBLAS<8, __m256, __m256, float, float, float> tb{
-            k, (const float *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__ARM_NEON)
-        if (n < 4)
-            return false;
-        if (k % 4)
-            return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
-            k, (const float *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#else
-        return false;
-#endif
-    }
-
-    case GGML_TYPE_F16: {
-#if defined(__AVX512F__)
-        if (k % 16)
-            return false;
-        if (Btype != GGML_TYPE_F32)
-            return false;
-        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
-        if (k % 8)
-            return false;
-        if (Btype != GGML_TYPE_F32)
-            return false;
-        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
-        if (n < 8)
-            return false;
-        if (k % 8)
-            return false;
-        if (Btype != GGML_TYPE_F16)
-            return false;
-        tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const ggml_fp16_t *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__ARM_NEON) && !defined(_MSC_VER)
-        if (k % 4)
-            return false;
-        if (Btype != GGML_TYPE_F32)
-            return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#else
-        return false;
-#endif
-    }
-
-    case GGML_TYPE_Q8_0: {
-        if (Btype != GGML_TYPE_Q8_0)
-           return false;
-#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
-        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
-            k, (const block_q8_0 *)A, lda,
-            (const block_q8_0 *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__ARM_FEATURE_DOTPROD)
-        tinyBLAS_Q0_ARM<block_q8_0> tb{
-            k, (const block_q8_0 *)A, lda,
-            (const block_q8_0 *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#else
-        return false;
-#endif
-    }
-
-    case GGML_TYPE_Q4_0: {
-        if (Btype != GGML_TYPE_Q8_0)
-            return false;
-#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
-        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
-            k, (const block_q4_0 *)A, lda,
-            (const block_q8_0 *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__ARM_FEATURE_DOTPROD)
-        tinyBLAS_Q0_ARM<block_q4_0> tb{
-            k, (const block_q4_0 *)A, lda,
-            (const block_q8_0 *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#else
-        return false;
-#endif
-    }
-
-    default:
-        return false;
-    }
-
-    (void)m;
-    (void)n;
-    (void)k;
-    (void)A;
-    (void)lda;
-    (void)B;
-    (void)ldb;
-    (void)C;
-    (void)ldc;
-    (void)ith;
-    (void)nth;
-    (void)Atype;
-    (void)Btype;
-    (void)Ctype;
-}
diff --git a/ggml/src/sgemm.h b/ggml/src/sgemm.h
deleted file mode 100644 (file)
index caf6dd5..0000000
+++ /dev/null
@@ -1,14 +0,0 @@
-#pragma once
-#include <stdint.h>
-#include <stdbool.h>
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
-                     const void *, int64_t, void *, int64_t, int, int,
-                     int, int, int);
-
-#ifdef __cplusplus
-}
-#endif