class tinyBLAS_PPC {
public:
tinyBLAS_PPC(int64_t k,
- const float *A, int64_t lda,
- const float *B, int64_t ldb,
- float *C, int64_t ldc,
+ const float * A, int64_t lda,
+ const float * B, int64_t ldb,
+ float * C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n) {
- mnpack(0, m, 0, n);
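+ // Take the cache-blocked path only when every dimension divides evenly by
+ // its block size; anything irregular falls back to the mnpack decomposition.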
+ int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
+ if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
+ matmul_tiled(m, n, mc, nc, kc);
+ } else {
+ mnpack(0, m, 0, n);
+ }
}
private:
- void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
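+ // Store a finished 4x4 MMA accumulator tile to C, overwriting what is there.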
+ inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+ vec_t vec_C[4];
+ __builtin_mma_disassemble_acc(vec_C, ACC);
+ for (int I = 0; I < 4; I++) {
+ for (int J = 0; J < 4; J++) {
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
+ }
+ }
+ }
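+ // Like save_acc, but accumulates into C (C += tile); the tiled path uses it
+ // for every k-block after the first.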
+ inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+ vec_t vec_C[4];
+ __builtin_mma_disassemble_acc(vec_C, ACC);
+ for (int I = 0; I < 4; I++) {
+ for (int J = 0; J < 4; J++) {
+ float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
+ *c_ptr += *((float *)&vec_C[I]+J);
+ }
+ }
+ }
- inline void vector_permute_store_4(vector float *src, float *vecOffset) {
+ inline void vector_permute_store_4(vector float * src, float * vecOffset) {
vector float t1, t2, t3, t4, t5, t6, t7, t8;
t1 = vec_mergeh(src[0], src[1]);
t2 = vec_mergeh(src[2], src[3]);
t3 = vec_mergel(src[0], src[1]);
t4 = vec_mergel(src[2], src[3]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t1, t2, 3);
t7 = vec_xxpermdi(t3, t4, 0);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 4);
vec_xst(t7, 0, vecOffset + 8);
vec_xst(t8, 0, vecOffset + 12);
}
- inline void vector_permute_store_8(vector float *src, float *vecOffset) {
+ inline void vector_permute_store_8(vector float * src, float * vecOffset) {
vector float t1, t2, t3, t4, t5, t6, t7, t8;
t1 = vec_mergeh(src[0], src[1]);
t2 = vec_mergeh(src[2], src[3]);
t3 = vec_mergeh(src[4], src[5]);
t4 = vec_mergeh(src[6], src[7]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 4);
vec_xst(t7, 0, vecOffset + 8);
vec_xst(t8, 0, vecOffset + 12);
t1 = vec_mergel(src[0], src[1]);
t2 = vec_mergel(src[2], src[3]);
t3 = vec_mergel(src[4], src[5]);
t4 = vec_mergel(src[6], src[7]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset + 16);
vec_xst(t6, 0, vecOffset + 20);
vec_xst(t7, 0, vecOffset + 24);
vec_xst(t8, 0, vecOffset + 28);
}
- void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
+ void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
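+ // Pack a rows x cols panel of 'a' (leading dimension lda) into 'vec',
+ // transposed block by block, so each packed vector holds four consecutive
+ // rows of a single column.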
int64_t i, j;
float * aoffsets[8];
- float *aoffset = NULL, *boffset = NULL;
+ float * aoffset = NULL, * boffset = NULL;
__vector_pair arr[8];
vector float c[8][2] = {0};
vector float c1[8] = {0};
vector float c2[8] = {0};
- aoffset = const_cast<float*>(a);
+ aoffset = const_cast<float *>(a);
boffset = vec;
j = (rows >> 3);
if (j > 0) {
-
do {
aoffsets[0] = aoffset;
- for (int it = 1; it< 8; it++)
+ for (int it = 1; it < 8; it++)
aoffsets[it] = aoffsets[it-1] + lda;
aoffset += 8 * lda;
i = (cols >> 3);
if (i > 0) {
do {
- for (int it = 0; it< 8; it++) {
+ for (int it = 0; it < 8; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
c1[it] = c[it][0];
c2[it] = c[it][1];
}
vector_permute_store_8(c1, boffset);
- vector_permute_store_8(c2, boffset+32);
- for (int it = 0; it < 4; it++)
- aoffsets[it] = aoffsets[it] + 8*lda;
+ vector_permute_store_8(c2, boffset + 32);
boffset += 64;
i--;
+ if (i > 0) {
+ for (int it = 0; it < 8; it++) {
+ aoffsets[it] = aoffsets[it] + 8;
+ }
+ }
} while(i > 0);
}
if (cols & 4) {
for (int it = 0; it < 8; it++)
c1[it] = vec_xl(0, aoffsets[it]);
vector_permute_store_8(c1, boffset);
boffset += 32;
}
j--;
} while(j > 0);
}
if (rows & 4) {
aoffsets[0] = aoffset;
for (int it = 1; it < 4; it++)
aoffsets[it] = aoffsets[it-1] + lda;
aoffset += 4 * lda;
i = (cols >> 3);
if (i > 0) {
do {
for (int it = 0; it < 4; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
c1[it] = c[it][0];
c2[it] = c[it][1];
}
vector_permute_store_4(c1, boffset);
- vector_permute_store_4(c2, boffset+16);
+ vector_permute_store_4(c2, boffset + 16);
for (int it = 0; it < 4; it++)
- aoffsets[it] += 8*lda;
+ aoffsets[it] += 8 * lda;
boffset += 32;
i--;
} while(i > 0);
}
}
}
void KERNEL_4x4(int64_t ii, int64_t jj) {
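+ // 4x4 tile of C: one accumulator, consuming four k-steps per iteration.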
vec_t vec_A[4], vec_B[4], vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
- for (int l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ for (int l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
}
- SAVE_ACC(&acc_0, ii, jj);
+ save_acc(&acc_0, ii, jj);
}
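+ // 4x8 tile of C: two accumulators over the same four rows of A.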
void KERNEL_4x8(int64_t ii, int64_t jj) {
vec_t vec_A[4], vec_B[8], vec_C[4];
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
- for (int64_t l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+ for (int64_t l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
}
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii, jj+4);
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii, jj + 4);
}
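+ // 8x4 tile of C: two accumulators over the same four columns of B.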
void KERNEL_8x4(int64_t ii, int64_t jj) {
vec_t vec_A[8], vec_B[4], vec_C[4];
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
- for (int64_t l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ for (int64_t l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
}
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii+4, jj);
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii + 4, jj);
}
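+ // 8x8 tile of C: four accumulators, consuming eight k-steps per iteration.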
void KERNEL_8x8(int64_t ii, int64_t jj) {
vec_t vec_A[16], vec_B[16], vec_C[4];
acc_t acc_0, acc_1, acc_2, acc_3;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
__builtin_mma_xxsetaccz(&acc_2);
__builtin_mma_xxsetaccz(&acc_3);
for (int l = 0; l < k; l+=8) {
- packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+ packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
for(int x = 0; x < 16; x+=2) {
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
- __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
- __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
- __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
}
}
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii, jj+4);
- SAVE_ACC(&acc_2, ii+4, jj);
- SAVE_ACC(&acc_3, ii+4, jj+4);
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii, jj + 4);
+ save_acc(&acc_2, ii + 4, jj);
+ save_acc(&acc_3, ii + 4, jj + 4);
}
+
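+ // Rank-1 updates for a 16x8 patch of C: eight 4x4 accumulators fed from two
+ // packed A blocks (rows i..i+7 and i+8..i+15) and one packed B block, eight
+ // k-steps per call.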
+ inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
+ for (int x = 0; x < 16; x += 2) {
+ __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
+ }
+ }
+
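+ // Compute one mc x nc tile of C in 16x8 sub-tiles from the packed buffers.
+ // Packed sub-blocks are 16 vectors each; the block for row-group i/8 and
+ // k-group l/8 starts at ((mc / 8) * (i / 8) + l / 8) * 16, an indexing that
+ // relies on the default square blocking (mc == nc == kc).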
+ void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
+ for (int64_t i = 0; i < mc; i += 16) {
+ int A_base_addr = (mc / 8) * (i / 8) * 16;
+ for (int64_t j = 0; j < nc; j += 8) {
+ int B_base_addr = (nc / 8) * (j / 8) * 16;
+ acc_t acc[8];
+ for (int x = 0; x < 8; x++)
+ __builtin_mma_xxsetaccz(&acc[x]);
+ for (int64_t l = 0; l < kc; l += 8) {
+ int A0_block_idx = A_base_addr + (l / 8) * 16;
+ int A1_block_idx = A0_block_idx + (mc / 8) * 16;
+ int B_block_idx = B_base_addr + (l / 8) * 16;
+ vec_t* A0_block = &vec_A[A0_block_idx];
+ vec_t* A1_block = &vec_A[A1_block_idx];
+ vec_t* B_block = &vec_B[B_block_idx];
+ MMA_16x8(A0_block, A1_block, B_block, acc);
+ }
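+ // The first k-block writes C directly; later k-blocks accumulate on top.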
+ if (kk == 0) {
+ save_acc(&acc[0], ii + i, jj + j);
+ save_acc(&acc[1], ii + i, jj + j + 4);
+ save_acc(&acc[2], ii + i + 4, jj + j);
+ save_acc(&acc[3], ii + i + 4, jj + j + 4);
+ save_acc(&acc[4], ii + i + 8, jj + j);
+ save_acc(&acc[5], ii + i + 8, jj + j + 4);
+ save_acc(&acc[6], ii + i + 12, jj + j);
+ save_acc(&acc[7], ii + i + 12, jj + j + 4);
+ } else {
+ add_save_acc(&acc[0], ii + i, jj + j);
+ add_save_acc(&acc[1], ii + i, jj + j + 4);
+ add_save_acc(&acc[2], ii + i + 4, jj + j);
+ add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
+ add_save_acc(&acc[4], ii + i + 8, jj + j);
+ add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
+ add_save_acc(&acc[6], ii + i + 12, jj + j);
+ add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
+ }
+ }
+ }
+ }
+
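+ // Cache-blocked GEMM: mc x nc tiles of C are distributed across threads, and
+ // each A/B panel is packed once per k-block before the MMA kernels run. With
+ // the default 256x256x256 blocking, each panel packs into 256 * 256 / 4 =
+ // 16384 vectors.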
+ void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
+ int64_t ytiles = m / mc;
+ int64_t xtiles = n / nc;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles) {
+ end = tiles;
+ }
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = (job / xtiles) * mc;
+ int64_t jj = (job % xtiles) * nc;
+ for (int64_t kk = 0; kk < k; kk += kc) {
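+ // Runtime-sized pack buffers rely on the GNU VLA extension; at the default
+ // blocking each buffer is 16384 vectors, i.e. 256 KiB of stack.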
+ vec_t A_pack[kc * mc / 4];
+ vec_t B_pack[kc * nc / 4];
+ packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
+ packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
+ KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
+ }
+ }
+ }
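+ // Recursively carve the output into the largest supported kernel tiles,
+ // then recurse on the right and bottom fringes.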
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int m_rem = MIN(m - m0, 8);
int n_rem = MIN(n - n0, 8);
int mc = 0, nc = 0;
if (m_rem >= 8 && n_rem >= 8) {
mc = 8;
nc = 8;
gemm<8, 8>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 8) {
mc = 4;
nc = 8;
gemm<4, 8>(m0, m, n0, n);
} else if (m_rem >= 8 && n_rem >= 4) {
mc = 8;
nc = 4;
gemm<8, 4>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 4) {
mc = 4;
nc = 4;
gemm<4, 4>(m0, m, n0, n);
} else {
mc = (m_rem >= 4) ? 4 : m_rem;
nc = (n_rem >= 4) ? 4 : n_rem;
if (mc == 0 || nc == 0)
return;
gemm_small(m0, m, n0, n, mc, nc);
}
int64_t mp = m0 + ((m - m0) / mc) * mc;
int64_t np = n0 + ((n - n0) / nc) * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
vec_t vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
- vec_t vec_A[4] {0}, vec_B[4] = {0};
- for (int l=0; l<k; l+=4) {
+ vec_t vec_A[4] = {0}, vec_B[4] = {0};
+ for (int l = 0; l < k; l += 4) {
/* 'GEMV forwarding' is used in the first two conditional branches:
* when one of the matrices has a single row/column, its elements are
* broadcast instead of being prepacked by the packing routine.
*/
if (RM == 1) {
- float* a = const_cast<float*>(A+(ii)*lda+l);
- packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+ float * a = const_cast<float *>(A + (ii) * lda + l);
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
vec_A[0] = (vec_t)vec_xl(0,a);
- vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
- vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
- vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+ vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
+ vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
+ vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
} else if (RN == 1) {
- packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
- float* b = const_cast<float*>(B+(jj)*ldb+l);
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
+ float * b = const_cast<float *>(B + (jj) * ldb + l);
vec_B[0] = (vec_t)vec_xl(0,b);
- vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
- vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
- vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
+ vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
+ vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
+ vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
} else {
- packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
}
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
}
__builtin_mma_disassemble_acc(vec_C, &acc_0);
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
- *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
}
}
}
}
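+ // Compile-time kernel dispatch; replaces the member-function-pointer
+ // indirection that gemm() previously set up per tile size.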
+ template<int RM, int RN>
+ inline void kernel(int64_t ii, int64_t jj) {
+ if constexpr(RM == 4 && RN == 4) {
+ KERNEL_4x4(ii, jj);
+ } else if constexpr(RM == 4 && RN == 8) {
+ KERNEL_4x8(ii, jj);
+ } else if constexpr(RM == 8 && RN == 4) {
+ KERNEL_8x4(ii, jj);
+ } else if constexpr(RM == 8 && RN == 8) {
+ KERNEL_8x8(ii, jj);
+ } else {
+ // A value-dependent condition keeps this assert in the discarded branch on
+ // pre-C++23 compilers, where a literal static_assert(false) always fires.
+ static_assert(RM == 0 && RN == 0, "RN/RM values not supported");
+ }
+ }
+
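+ // Partition the RM x RN tiles of C across threads; each job runs one kernel.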
template <int RM, int RN>
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
- if (RM == 4 && RN == 4) {
- kernel = &tinyBLAS_PPC::KERNEL_4x4;
- } else if (RM == 4 && RN == 8) {
- kernel = &tinyBLAS_PPC::KERNEL_4x8;
- } else if (RM == 8 && RN == 4) {
- kernel = &tinyBLAS_PPC::KERNEL_8x4;
- } else if (RM == 8 && RN == 8) {
- kernel = &tinyBLAS_PPC::KERNEL_8x8;
- }
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
- (this->*kernel)(ii, jj);
+ kernel<RM, RN>(ii, jj);
}
}
- const float *const A;
- const float *const B;
- float *C;
+ const float * const A;
+ const float * const B;
+ float * C;
const int64_t k;
const int64_t lda;
const int64_t ldb;