#include "ggml-cpu-impl.h"
#include "ggml-quants.h"
+#include <atomic>
+
#ifdef _MSC_VER
#define NOINLINE __declspec(noinline)
#else
return _mm512_fmadd_ps(a, b, c);
}
#endif
+#if defined(__AVX512BF16__)
+template <>
+inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
+ return _mm512_dpbf16_ps(c, a, b);
+}
+template <>
+inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
+ return _mm256_dpbf16_ps(c, a, b);
+}
+#endif
#endif
#if defined(__ARM_FEATURE_FMA)
}
#endif // __AVX__
+#if defined(__AVX2__) || defined(__AVX512F__)
+template <> inline __m256 load(const ggml_bf16_t *p) {
+ return _mm256_castsi256_ps(
+ _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
+}
+#endif // __AVX2__ || __AVX512F__
+
#if defined(__F16C__)
template <> inline __m256 load(const ggml_fp16_t *p) {
return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
template <> inline __m512 load(const ggml_fp16_t *p) {
return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
}
+template <> inline __m512 load(const ggml_bf16_t *p) {
+ return _mm512_castsi512_ps(
+ _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
+}
#endif // __AVX512F__
+#if defined(__AVX512BF16__)
+template <> inline __m512bh load(const ggml_bf16_t *p) {
+ return (__m512bh)_mm512_loadu_ps((const float *)p);
+}
+template <> inline __m256bh load(const ggml_bf16_t *p) {
+ return (__m256bh)_mm256_loadu_ps((const float *)p);
+}
+template <> inline __m512bh load(const float *p) {
+ return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
+}
+template <> inline __m256bh load(const float *p) {
+ return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
+}
+#endif
+
////////////////////////////////////////////////////////////////////////////////////////////////////
// CONSTANTS
////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION
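+// Split m into ceil(m/M) nearly equal blocks and return the (rounded-up) block size.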
+template <int M>
+static inline int64_t BLOCK_SIZE(size_t m) {
+ const int64_t NB_BLOC_M = (m + M - 1) / M;
+ return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
+}
+
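+// Starting offset of block ib when the first ibN blocks have size bloc_size and the remaining blocks have size bloc_size - 1.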
+static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
+ return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
+}
+
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
public:
- tinyBLAS(int64_t k,
+ tinyBLAS(const ggml_compute_params * params, int64_t k,
const TA *A, int64_t lda,
const TB *B, int64_t ldb,
- TC *C, int64_t ldc,
- int ith, int nth)
- : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+ TC *C, int64_t ldc)
+ : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
}
- void matmul(int64_t m, int64_t n) {
- mnpack(0, m, 0, n);
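+    // Returns true if the matmul was handled here; false means the shape is unsupported and the caller must fall back.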
+ bool matmul(int64_t m, int64_t n) {
+ if (k % KN != 0)
+ return false;
+ // pick RM/BM from m (rows must be a multiple of RM*BM); the N dimension then only needs tiles of size RN and RN-1.
+#if VECTOR_REGISTERS == 32
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+ mnpack<4, 6, 4>(m, n, SIZE_N, 12);
+ return true;
+ }
+ if (m % 8 == 0) {
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+ mnpack<4, 6, 2>(m, n, SIZE_N, 12);
+ return true;
+ }
+ if (m % 4 == 0) {
+ const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+ mnpack<4, 6, 1>(m, n, SIZE_N, 12);
+ return true;
+ }
+#else // VECTOR_REGISTERS == 16
+ if (m % 16 == 0 && (m/16 >= params->nth)) {
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+ mnpack<4, 3, 4>(m, n, SIZE_N, 24);
+ return true;
+ }
+ if (m % 8 == 0) {
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+ mnpack<4, 3, 2>(m, n, SIZE_N, 24);
+ return true;
+ }
+ if (m % 4 == 0) {
+ const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+ mnpack<4, 3, 1>(m, n, SIZE_N, 24);
+ return true;
+ }
+#endif
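+ // m is not a multiple of 4: let the caller fall back to the generic path.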
+ return false;
}
private:
- NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
-#if VECTOR_REGISTERS == 32
- case 0x55:
- mc = 5;
- nc = 5;
- gemm<5, 5>(m0, m, n0, n);
- break;
- case 0x45:
- mc = 4;
- nc = 5;
- gemm<4, 5>(m0, m, n0, n);
- break;
- case 0x54:
- mc = 5;
- nc = 4;
- gemm<5, 4>(m0, m, n0, n);
- break;
- case 0x44:
- mc = 4;
- nc = 4;
- gemm<4, 4>(m0, m, n0, n);
- break;
- case 0x53:
- mc = 5;
- nc = 3;
- gemm<5, 3>(m0, m, n0, n);
- break;
- case 0x35:
- mc = 3;
- nc = 5;
- gemm<3, 5>(m0, m, n0, n);
- break;
- case 0x43:
- mc = 4;
- nc = 3;
- gemm<4, 3>(m0, m, n0, n);
- break;
-#else
- case 0x55:
- case 0x54:
- case 0x53:
- case 0x45:
- case 0x44:
- case 0x43:
- mc = 4;
- nc = 3;
- gemm<4, 3>(m0, m, n0, n);
- break;
- case 0x35:
-#endif
- case 0x34:
- mc = 3;
- nc = 4;
- gemm<3, 4>(m0, m, n0, n);
- break;
- case 0x52:
- mc = 5;
- nc = 2;
- gemm<5, 2>(m0, m, n0, n);
- break;
- case 0x33:
- mc = 3;
- nc = 3;
- gemm<3, 3>(m0, m, n0, n);
- break;
- case 0x25:
- mc = 2;
- nc = 5;
- gemm<2, 5>(m0, m, n0, n);
- break;
- case 0x42:
- mc = 4;
- nc = 2;
- gemm<4, 2>(m0, m, n0, n);
- break;
- case 0x24:
- mc = 2;
- nc = 4;
- gemm<2, 4>(m0, m, n0, n);
- break;
- case 0x32:
- mc = 3;
- nc = 2;
- gemm<3, 2>(m0, m, n0, n);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm<2, 3>(m0, m, n0, n);
- break;
- case 0x51:
- mc = 5;
- nc = 1;
- gemm<5, 1>(m0, m, n0, n);
- break;
- case 0x41:
- mc = 4;
- nc = 1;
- gemm<4, 1>(m0, m, n0, n);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm<2, 2>(m0, m, n0, n);
- break;
- case 0x15:
- mc = 1;
- nc = 5;
- gemm<1, 5>(m0, m, n0, n);
- break;
- case 0x14:
- mc = 1;
- nc = 4;
- gemm<1, 4>(m0, m, n0, n);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm<3, 1>(m0, m, n0, n);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm<1, 3>(m0, m, n0, n);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm<2, 1>(m0, m, n0, n);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm<1, 2>(m0, m, n0, n);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm<1, 1>(m0, m, n0, n);
- break;
- default:
- return;
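+    // Recursive template dispatch: walk RN down until it matches the requested SIZE_N, then run the matching gemm kernel.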
+ template <int RM, int RN, int BM>
+ inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
+ if (SIZE_N == RN) {
+ return gemm<RM, RN, BM>(m, n, BN);
+ }
+ if constexpr (RN > 1) {
+ return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
+ } else {
+ GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
+ GGML_ASSERT(false); // we have missed something.
}
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
- mnpack(mp, m, n0, np);
- mnpack(m0, m, np, n);
}
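+    // Compute a single RM x RN tile of C, with rows starting at ii and columns at jj.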
template <int RM, int RN>
- NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t ytiles = (m - m0) / RM;
- int64_t xtiles = (n - n0) / RN;
- int64_t tiles = xtiles * ytiles;
- int64_t duty = (tiles + nth - 1) / nth;
- int64_t start = duty * ith;
- int64_t end = start + duty;
- if (end > tiles)
- end = tiles;
- for (int64_t job = start; job < end; ++job) {
- int64_t ii = m0 + job / xtiles * RM;
- int64_t jj = n0 + job % xtiles * RN;
- D Cv[RN][RM] = {};
- for (int64_t l = 0; l < k; l += KN)
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
- load<V>(B + ldb * (jj + j) + l),
- Cv[j][i]);
- for (int64_t j = 0; j < RN; ++j)
- for (int64_t i = 0; i < RM; ++i)
- C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+ inline void gemm_bloc(int64_t ii, int64_t jj) {
+ D Cv[RN][RM] = {};
+ for (int64_t l = 0; l < k; l += KN) {
+ // help the compiler with operation ordering: preload the smaller of the A/B slices and reuse it.
+ if constexpr (RM <= RN) {
+ V Av[RM];
+ for (int64_t i = 0; i < RM; ++i) {
+ Av[i] = load<V>(A + lda * (ii + i) + l);
+ }
+ for (int64_t j = 0; j < RN; ++j) {
+ V Bv = load<V>(B + ldb * (jj + j) + l);
+ for (int64_t i = 0; i < RM; ++i) {
+ Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
+ }
+ }
+ } else {
+ V Bv[RN];
+ for (int64_t j = 0; j < RN; ++j) {
+ Bv[j] = load<V>(B + ldb * (jj + j) + l);
+ }
+ for (int64_t i = 0; i < RM; ++i) {
+ V Av = load<V>(A + lda * (ii + i) + l);
+ for (int64_t j = 0; j < RN; ++j) {
+ Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
+ }
+ }
+ }
}
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
}
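+    // Threaded driver: rows are split into blocks of RM*BM and columns into blocks of roughly BN tiles; work is handed out through a shared atomic chunk counter.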
+ template <int RM, int RN, int BM>
+ NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
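+        // chunk counter shared by every thread; thread 0 seeds it before the barrier below.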
+ static std::atomic<int64_t> current_chunk;
+
+ GGML_ASSERT(m % (RM * BM) == 0);
+ const int64_t ytiles = m / (RM * BM);
+        const int64_t xtiles = (n + RN - 1) / RN;
+ const int64_t jj_RN = (xtiles - (xtiles * RN - n));
+
+ // "round" bloc_size to "nearest" BN
+ const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
+ const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
+ const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
+ const int64_t nb_job = ytiles * NB_BN;
+
+ if (params->ith == 0) {
+ GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+ std::atomic_store_explicit(¤t_chunk, (int64_t)params->nth, std::memory_order_relaxed);
+ }
+
+ ggml_barrier(params->threadpool);
+
+ int64_t job = params->ith;
+ while (job < nb_job) {
+ const int64_t ii = (job % ytiles) * RM * BM;
+ const int64_t jb = job / ytiles;
+ const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
+ const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
+
+ const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
+ const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
+ const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
+
+ for (int64_t bi = 0; bi < BM * RM; bi += RM) {
+ int64_t jj = jj0;
+ for (; jj < jj1; jj += RN) {
+ gemm_bloc<RM, RN>(ii + bi, jj);
+ }
+ if constexpr (RN > 1) {
+ for (; jj < jj2; jj += RN - 1) {
+ gemm_bloc<RM, RN-1>(ii + bi, jj);
+ }
+ }
+ GGML_ASSERT(jj == jj2);
+ }
+
+            // grab the next chunk of work.
+ job = std::atomic_fetch_add_explicit(¤t_chunk, (int64_t)1, std::memory_order_relaxed);
+ }
+
+ ggml_barrier(params->threadpool);
+ return;
+ }
+
+ const ggml_compute_params * params;
const TA *const A;
const TB *const B;
TC *const C;
const int64_t lda;
const int64_t ldb;
const int64_t ldc;
- const int ith;
- const int nth;
};
//////////////////////////////////////////////////////////////////////////////////////////
* @param Ctype is GGML data type of `C`
* @return true if this function was able to service the matmul request
*/
-bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
- int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
+bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+ const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+ int64_t ldc, int Atype, int Btype, int Ctype) {
assert(m >= 0);
assert(n >= 0);
assert(lda >= k);
assert(ldb >= k);
assert(ldc >= m);
- assert(nth > 0);
- assert(ith < nth);
+ assert(params->nth > 0);
+ assert(params->ith < params->nth);
// only enable sgemm for prompt processing
if (n < 2)
if (Btype != GGML_TYPE_F32)
return false;
#if defined(__AVX512F__)
- if (k % 16)
- return false;
- tinyBLAS<16, __m512, __m512, float, float, float> tb{
+ tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
k, (const float *)A, lda,
(const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
+ (float *)C, ldc};
+ return tb.matmul(m, n);
#elif defined(__AVX__) || defined(__AVX2__)
- if (k % 8)
- return false;
- tinyBLAS<8, __m256, __m256, float, float, float> tb{
+ tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
k, (const float *)A, lda,
(const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
+ (float *)C, ldc};
+ return tb.matmul(m, n);
#elif defined(__ARM_NEON)
if (n < 4)
return false;
- if (k % 4)
- return false;
- tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+ tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
k, (const float *)A, lda,
(const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
+ (float *)C, ldc};
+ return tb.matmul(m, n);
#elif defined(__MMA__)
if (k % 8)
return false;
k, (const float *)A, lda,
(const float *)B, ldb,
(float *)C, ldc,
- ith, nth};
+ params->ith, params->nth};
tb.matmul(m, n);
return true;
#else
#endif
}
+ case GGML_TYPE_BF16: {
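+        // With AVX512-BF16 the dot products run natively on bf16 pairs; otherwise bf16 inputs are widened to f32 on load.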
+#if defined(__AVX512BF16__)
+ if (Btype == GGML_TYPE_BF16) {
+ tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
+ (const ggml_bf16_t *)A, lda,
+ (const ggml_bf16_t *)B, ldb,
+ (float *)C, ldc};
+ return tb.matmul(m, n);
+ }
+#elif defined(__AVX512F__)
+ if (Btype == GGML_TYPE_BF16) {
+ tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
+ (const ggml_bf16_t *)A, lda,
+ (const ggml_bf16_t *)B, ldb,
+ (float *)C, ldc};
+ return tb.matmul(m, n);
+ }
+#elif defined(__AVX2__)
+ if (Btype == GGML_TYPE_BF16) {
+ tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
+ (const ggml_bf16_t *)A, lda,
+ (const ggml_bf16_t *)B, ldb,
+ (float *)C, ldc};
+ return tb.matmul(m, n);
+ }
+#endif
+ return false;
+ }
case GGML_TYPE_F16: {
#if defined(__AVX512F__)
- if (k % 16)
- return false;
- if (Btype != GGML_TYPE_F32)
- return false;
- tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
+ if (Btype == GGML_TYPE_F16) {
+ tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
+ (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc};
+ return tb.matmul(m, n);
+ }
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
- if (k % 8)
- return false;
- if (Btype != GGML_TYPE_F32)
- return false;
- tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
+ if (Btype == GGML_TYPE_F16) {
+ tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
+ (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc};
+ return tb.matmul(m, n);
+ }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 8)
return false;
- if (k % 8)
- return false;
- if (Btype != GGML_TYPE_F16)
- return false;
- tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const ggml_fp16_t *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
+ if (Btype == GGML_TYPE_F16) {
+ tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+ k, (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc};
+ return tb.matmul(m, n);
+ }
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
- if (k % 4)
- return false;
- if (Btype != GGML_TYPE_F32)
- return false;
- tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
- k, (const ggml_fp16_t *)A, lda,
- (const float *)B, ldb,
- (float *)C, ldc,
- ith, nth};
- tb.matmul(m, n);
- return true;
-#else
- return false;
+ if (Btype == GGML_TYPE_F32) {
+ tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
+ k, (const ggml_fp16_t *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc};
+ return tb.matmul(m, n);
+ }
#endif
+ return false;
}
case GGML_TYPE_Q8_0: {
k, (const block_q8_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
- ith, nth};
+ params->ith, params->nth};
tb.matmul(m, n);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
k, (const block_q8_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
- ith, nth};
+ params->ith, params->nth};
tb.matmul(m, n);
return true;
#else
k, (const block_q4_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
- ith, nth};
+ params->ith, params->nth};
tb.matmul(m, n);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
k, (const block_q4_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
- ith, nth};
+ params->ith, params->nth};
tb.matmul(m, n);
return true;
#else
k, (const block_q5_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
- ith, nth};
+ params->ith, params->nth};
tb.matmul(m, n);
return true;
#else
k, (const block_iq4_nl *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
- ith, nth};
+ params->ith, params->nth};
tb.matmul(m, n);
return true;
#else
return false;
}
+ (void)params;
(void)m;
(void)n;
(void)k;
(void)ldb;
(void)C;
(void)ldc;
- (void)ith;
- (void)nth;
(void)Atype;
(void)Btype;
(void)Ctype;