#pragma GCC diagnostic ignored "-Wignored-attributes"
#include "sgemm.h"
-#include <algorithm>
#include "ggml-impl.h"
#include "ggml-quants.h"
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
public:
- tinyBLAS(int k,
- const TA *A, int lda,
- const TB *B, int ldb,
- TC *C, int ldc,
+ tinyBLAS(int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
- void matmul(int m, int n, int task) {
+ void matmul(int64_t m, int64_t n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
}
private:
- NOINLINE void mnpack(int m0, int m, int n0, int n) {
- int mc, nc, mp, np;
- switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
#if VECTOR_REGISTERS == 32
case 0x55:
mc = 5;
}
template <int RM, int RN>
- NOINLINE void gemm(int m0, int m, int n0, int n) {
- int ytiles = (m - m0) / RM;
- int xtiles = (n - n0) / RN;
- int tiles = xtiles * ytiles;
- int duty = (tiles + nth - 1) / nth;
- int start = duty * ith;
- int end = start + duty;
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
if (end > tiles)
end = tiles;
- for (int job = start; job < end; ++job) {
- int ii = m0 + job / xtiles * RM;
- int jj = n0 + job % xtiles * RN;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
D Cv[RN][RM] = {};
- for (int l = 0; l < k; l += KN)
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ for (int64_t l = 0; l < k; l += KN)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
load<V>(B + ldb * (jj + j) + l),
Cv[j][i]);
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
}
}
const TA *const A;
const TB *const B;
TC *const C;
- const int k;
- const int lda;
- const int ldb;
- const int ldc;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
const int ith;
const int nth;
};
template <typename TA>
class tinyBLAS_Q0_ARM {
public:
- tinyBLAS_Q0_ARM(int k,
- const TA *A, int lda,
- const block_q8_0 *B, int ldb,
- float *C, int ldc,
+ tinyBLAS_Q0_ARM(int64_t k,
+ const TA *A, int64_t lda,
+ const block_q8_0 *B, int64_t ldb,
+ float *C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
- void matmul(int m, int n, int task) {
+ void matmul(int64_t m, int64_t n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
}
private:
- NOINLINE void mnpack(int m0, int m, int n0, int n) {
- int mc, nc, mp, np;
- switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) {
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
case 0x33:
mc = 3;
nc = 3;
}
template <int RM, int RN>
- NOINLINE void gemm(int m0, int m, int n0, int n) {
- int ytiles = (m - m0) / RM;
- int xtiles = (n - n0) / RN;
- int tiles = xtiles * ytiles;
- int duty = (tiles + nth - 1) / nth;
- int start = duty * ith;
- int end = start + duty;
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
if (end > tiles)
end = tiles;
- for (int job = start; job < end; ++job) {
- int ii = m0 + job / xtiles * RM;
- int jj = n0 + job % xtiles * RN;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
float32x4_t Cv[RN][RM] = {};
- for (int l = 0; l < k; ++l)
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ for (int64_t l = 0; l < k; ++l)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
Cv[j][i] = vmlaq_n_f32(Cv[j][i],
vcvtq_f32_s32(vdotq_s32(
vdotq_s32(vdupq_n_s32(0),
load_hi(B + ldb * (jj + j) + l))),
unhalf(A[lda * (ii + i) + l].d) *
unhalf(B[ldb * (jj + j) + l].d));
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
}
}
const TA *const A;
const block_q8_0 *const B;
float *const C;
- const int k;
- const int lda;
- const int ldb;
- const int ldc;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
const int ith;
const int nth;
};
template <typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX2 {
public:
- tinyBLAS_Q0_AVX2(int k,
- const TA *A, int lda,
- const TB *B, int ldb,
- TC *C, int ldc,
+ tinyBLAS_Q0_AVX2(int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
- void matmul(int m, int n, int task) {
+ void matmul(int64_t m, int64_t n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
}
private:
- void mnpack(int m0, int m, int n0, int n) {
- int mc, nc, mp, np;
- switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
#if VECTOR_REGISTERS == 32
case 0x44:
mc = 4;
}
template <int RM, int RN>
- NOINLINE void gemm(int m0, int m, int n0, int n) {
- int ytiles = (m - m0) / RM;
- int xtiles = (n - n0) / RN;
- int tiles = xtiles * ytiles;
- int duty = (tiles + nth - 1) / nth;
- int start = duty * ith;
- int end = start + duty;
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
if (end > tiles)
end = tiles;
- for (int job = start; job < end; ++job) {
- int ii = m0 + job / xtiles * RM;
- int jj = n0 + job % xtiles * RN;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
__m256 Cv[RN][RM] = {};
- for (int l = 0; l < k; ++l)
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ for (int64_t l = 0; l < k; ++l)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
unhalf(B[ldb * (jj + j) + l].d)),
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
load(A + lda * (ii + i) + l))),
Cv[j][i]);
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
}
}
const TA *const A;
const TB *const B;
TC *const C;
- const int k;
- const int lda;
- const int ldb;
- const int ldc;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
const int ith;
const int nth;
};
* @param Ctype is GGML data type of `C`
* @return true if this function was able to service the matmul request
*/
-bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C,
- int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
+bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+ int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
assert(m >= 0);
assert(n >= 0);
assert(ldc >= m);
assert(nth > 0);
assert(ith < nth);
- assert(1ll * lda * m <= 0x7fffffff);
- assert(1ll * ldb * n <= 0x7fffffff);
- assert(1ll * ldc * n <= 0x7fffffff);
if (Ctype != GGML_TYPE_F32)
return false;