inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#if defined(__MMA__)
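+// POWER10 Matrix-Multiply Assist (MMA) types: vec_t is a 16-byte VSX vector
+// used for kernel operands; acc_t (__vector_quad) is the 512-bit accumulator
+// that holds a 4x4 tile of f32 partial sums.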
+typedef vector unsigned char vec_t;
+typedef __vector_quad acc_t;
+#endif // __MMA__
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD
};
#endif // __AVX__
+// PPC implementation (POWER10 MMA)
+#if defined(__MMA__)
+
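+// Store one finished accumulator tile: disassemble ACC into four 4-float
+// rows (vec_C) and scatter them into the column-major C at block (ii, jj).
+// The macro expects a vec_t vec_C[4] to be in scope at the expansion site.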
+#define SAVE_ACC(ACC, ii, jj) \
+ __builtin_mma_disassemble_acc(vec_C, ACC); \
+ for (int I = 0; I < 4; I++) { \
+ for (int J = 0; J < 4; J++) { \
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
+ } \
+ }
+
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_PPC {
+ public:
+ tinyBLAS_PPC(int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc,
+ int ith, int nth)
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+ }
+
+ void matmul(int64_t m, int64_t n) {
+ mnpack(0, m, 0, n);
+ }
+
+ private:
+
+ void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
+
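+ // Pack a rows x cols block of the lda-strided matrix `a` into `vec`,
+ // transposed, so that each group of four floats holds one column of the
+ // block -- the operand layout __builtin_mma_xvf32gerpp expects. The
+ // transpose is built from vec_mergeh/vec_mergel (lane interleaves) and
+ // vec_xxpermdi (doubleword permutes) on 4-float VSX vectors.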
+ void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
+ int64_t i, j;
+ float *aoffset = NULL, *boffset = NULL;
+ float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+ float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+
+ aoffset = const_cast<float*>(a);
+ boffset = vec;
+ j = (rows >> 3);
+ if (j > 0) {
+ do {
+ aoffset1 = aoffset;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ aoffset4 = aoffset3 + lda;
+ aoffset5 = aoffset4 + lda;
+ aoffset6 = aoffset5 + lda;
+ aoffset7 = aoffset6 + lda;
+ aoffset8 = aoffset7 + lda;
+ aoffset += 8 * lda;
+ i = (cols >> 3);
+ if (i > 0) {
+ __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+ vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
+ do {
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
+ C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
+ C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
+ C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
+ C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
+ __builtin_vsx_disassemble_pair(c1, &C1);
+ __builtin_vsx_disassemble_pair(c2, &C2);
+ __builtin_vsx_disassemble_pair(c3, &C3);
+ __builtin_vsx_disassemble_pair(c4, &C4);
+ __builtin_vsx_disassemble_pair(c5, &C5);
+ __builtin_vsx_disassemble_pair(c6, &C6);
+ __builtin_vsx_disassemble_pair(c7, &C7);
+ __builtin_vsx_disassemble_pair(c8, &C8);
+
+ t1 = vec_mergeh(c1[0], c2[0]);
+ t2 = vec_mergeh(c3[0], c4[0]);
+ t3 = vec_mergeh(c5[0], c6[0]);
+ t4 = vec_mergeh(c7[0], c8[0]);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+ vec_xst(t5, 0, boffset);
+ vec_xst(t6, 0, boffset+4);
+ vec_xst(t7, 0, boffset+8);
+ vec_xst(t8, 0, boffset+12);
+
+ t1 = vec_mergel(c1[0], c2[0]);
+ t2 = vec_mergel(c3[0], c4[0]);
+ t3 = vec_mergel(c5[0], c6[0]);
+ t4 = vec_mergel(c7[0], c8[0]);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+ vec_xst(t5, 0, boffset+16);
+ vec_xst(t6, 0, boffset+20);
+ vec_xst(t7, 0, boffset+24);
+ vec_xst(t8, 0, boffset+28);
+
+ t1 = vec_mergeh(c1[1], c2[1]);
+ t2 = vec_mergeh(c3[1], c4[1]);
+ t3 = vec_mergeh(c5[1], c6[1]);
+ t4 = vec_mergeh(c7[1], c8[1]);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+ vec_xst(t5, 0, boffset+32);
+ vec_xst(t6, 0, boffset+36);
+ vec_xst(t7, 0, boffset+40);
+ vec_xst(t8, 0, boffset+44);
+
+ t1 = vec_mergel(c1[1], c2[1]);
+ t2 = vec_mergel(c3[1], c4[1]);
+ t3 = vec_mergel(c5[1], c6[1]);
+ t4 = vec_mergel(c7[1], c8[1]);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+ vec_xst(t5, 0, boffset+48);
+ vec_xst(t6, 0, boffset+52);
+ vec_xst(t7, 0, boffset+56);
+ vec_xst(t8, 0, boffset+60);
+
+ // advance all eight row pointers past the eight columns just packed
+ aoffset1 += 8;
+ aoffset2 += 8;
+ aoffset3 += 8;
+ aoffset4 += 8;
+ aoffset5 += 8;
+ aoffset6 += 8;
+ aoffset7 += 8;
+ aoffset8 += 8;
+ boffset += 64;
+ i--;
+ } while(i > 0);
+ }
+ if (cols & 4) {
+ vector float c1, c2, c3, c4, c5, c6, c7, c8;
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
+ c1 = vec_xl(0, aoffset1);
+ c2 = vec_xl(0, aoffset2);
+ c3 = vec_xl(0, aoffset3);
+ c4 = vec_xl(0, aoffset4);
+ c5 = vec_xl(0, aoffset5);
+ c6 = vec_xl(0, aoffset6);
+ c7 = vec_xl(0, aoffset7);
+ c8 = vec_xl(0, aoffset8);
+
+ t1 = vec_mergeh(c1, c2);
+ t2 = vec_mergeh(c3, c4);
+ t3 = vec_mergeh(c5, c6);
+ t4 = vec_mergeh(c7, c8);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+ vec_xst(t5, 0, boffset);
+ vec_xst(t6, 0, boffset+4);
+ vec_xst(t7, 0, boffset+8);
+ vec_xst(t8, 0, boffset+12);
+
+ t1 = vec_mergel(c1, c2);
+ t2 = vec_mergel(c3, c4);
+ t3 = vec_mergel(c5, c6);
+ t4 = vec_mergel(c7, c8);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+ vec_xst(t5, 0, boffset+16);
+ vec_xst(t6, 0, boffset+20);
+ vec_xst(t7, 0, boffset+24);
+ vec_xst(t8, 0, boffset+28);
+ boffset += 32; // advance past the 8x4 block just packed
+ }
+ j--;
+ } while(j > 0);
+ }
+
+ if (rows & 4) {
+ aoffset1 = aoffset;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ aoffset4 = aoffset3 + lda;
+ aoffset += 4 * lda;
+ i = (cols >> 3);
+ if (i > 0) {
+ __vector_pair C1, C2, C3, C4;
+ vector float c1[2], c2[2], c3[2], c4[2];
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
+ do {
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
+ __builtin_vsx_disassemble_pair(c1, &C1);
+ __builtin_vsx_disassemble_pair(c2, &C2);
+ __builtin_vsx_disassemble_pair(c3, &C3);
+ __builtin_vsx_disassemble_pair(c4, &C4);
+
+ t1 = vec_mergeh(c1[0], c2[0]);
+ t2 = vec_mergeh(c3[0], c4[0]);
+ t3 = vec_mergel(c1[0], c2[0]);
+ t4 = vec_mergel(c3[0], c4[0]);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t1, t2, 3);
+ t7 = vec_xxpermdi(t3, t4, 0);
+ t8 = vec_xxpermdi(t3, t4, 3);
+ vec_xst(t5, 0, boffset);
+ vec_xst(t6, 0, boffset+4);
+ vec_xst(t7, 0, boffset+8);
+ vec_xst(t8, 0, boffset+12);
+
+ t1 = vec_mergeh(c1[1], c2[1]);
+ t2 = vec_mergeh(c3[1], c4[1]);
+ t3 = vec_mergel(c1[1], c2[1]);
+ t4 = vec_mergel(c3[1], c4[1]);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t1, t2, 3);
+ t7 = vec_xxpermdi(t3, t4, 0);
+ t8 = vec_xxpermdi(t3, t4, 3);
+ vec_xst(t5, 0, boffset+16);
+ vec_xst(t6, 0, boffset+20);
+ vec_xst(t7, 0, boffset+24);
+ vec_xst(t8, 0, boffset+28);
+
+ // advance the four row pointers past the eight columns just packed
+ aoffset1 += 8;
+ aoffset2 += 8;
+ aoffset3 += 8;
+ aoffset4 += 8;
+ boffset += 32;
+ i--;
+ } while(i > 0);
+ }
+
+ if (cols & 4) {
+ vector float c1, c2, c3, c4;
+ vector float t1, t2, t3, t4;
+ c1 = vec_xl(0, aoffset1);
+ c2 = vec_xl(0, aoffset2);
+ c3 = vec_xl(0, aoffset3);
+ c4 = vec_xl(0, aoffset4);
+
+ t1 = vec_mergeh(c1, c2);
+ t2 = vec_mergeh(c3, c4);
+ t3 = vec_xxpermdi(t1, t2, 0);
+ t4 = vec_xxpermdi(t1, t2, 3);
+ vec_xst(t3, 0, boffset);
+ vec_xst(t4, 0, boffset+4);
+
+ t1 = vec_mergel(c1, c2);
+ t2 = vec_mergel(c3, c4);
+ t3 = vec_xxpermdi(t1, t2, 0);
+ t4 = vec_xxpermdi(t1, t2, 3);
+ vec_xst(t3, 0, boffset+8);
+ vec_xst(t4, 0, boffset+12);
+ boffset += 16; // leave boffset at fresh space for the rows&3 tail
+ }
+ }
+ if (rows & 3) {
+ aoffset1 = aoffset;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ if (cols & 4) {
+ vector float c1 = {0}, c2 = {0}, c3 = {0}, c4 = {0};
+ vector float t1, t2, t3, t4;
+ // load only the rows that actually remain; the zeroed rows keep the
+ // transpose well-defined without reading past the block
+ switch (rows & 3) {
+ case 3: c3 = vec_xl(0, aoffset3); // fall through
+ case 2: c2 = vec_xl(0, aoffset2); // fall through
+ case 1: c1 = vec_xl(0, aoffset1);
+ }
+
+ t1 = vec_mergeh(c1, c2);
+ t2 = vec_mergeh(c3, c4);
+ t3 = vec_xxpermdi(t1, t2, 0);
+ t4 = vec_xxpermdi(t1, t2, 3);
+ vec_xst(t3, 0, boffset);
+ vec_xst(t4, 0, boffset+4);
+
+ t1 = vec_mergel(c1, c2);
+ t2 = vec_mergel(c3, c4);
+ t3 = vec_xxpermdi(t1, t2, 0);
+ t4 = vec_xxpermdi(t1, t2, 3);
+ vec_xst(t3, 0, boffset+8);
+ vec_xst(t4, 0, boffset+12);
+ }
+ }
+ }
+
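+ // Each __builtin_mma_xvf32gerpp(&acc, a, b) performs a rank-1 update,
+ // acc[i][j] += a[i] * b[j]. With vec_A[t] and vec_B[t] holding column t
+ // of the packed blocks, the four updates below accumulate a 4x4 tile of
+ // C across four k-steps per iteration.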
+ void KERNEL_4x4(int64_t ii, int64_t jj) {
+ vec_t vec_A[4], vec_B[4], vec_C[4];
+ acc_t acc_0;
+ __builtin_mma_xxsetaccz(&acc_0);
+ for (int64_t l = 0; l < k; l += 4) {
+ READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
+ }
+ SAVE_ACC(&acc_0, ii, jj);
+ }
+
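+ // 4x8 tile: one packed A block feeds two accumulators covering columns
+ // jj..jj+3 and jj+4..jj+7 of C.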
+ void KERNEL_4x8(int64_t ii, int64_t jj) {
+ vec_t vec_A[4], vec_B[8], vec_C[4];
+ acc_t acc_0, acc_1;
+ __builtin_mma_xxsetaccz(&acc_0);
+ __builtin_mma_xxsetaccz(&acc_1);
+ for (int64_t l = 0; l < k; l+=4) {
+ READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
+ }
+ SAVE_ACC(&acc_0, ii, jj);
+ SAVE_ACC(&acc_1, ii, jj+4);
+ }
+
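+ // 8x4 tile: the mirror of KERNEL_4x8, with the two accumulators
+ // covering rows ii..ii+3 and ii+4..ii+7 of C.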
+ void KERNEL_8x4(int64_t ii, int64_t jj) {
+ vec_t vec_A[8], vec_B[4], vec_C[4];
+ acc_t acc_0, acc_1;
+ __builtin_mma_xxsetaccz(&acc_0);
+ __builtin_mma_xxsetaccz(&acc_1);
+ for (int64_t l = 0; l < k; l+=4) {
+ READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
+ }
+ SAVE_ACC(&acc_0, ii, jj);
+ SAVE_ACC(&acc_1, ii+4, jj);
+ }
+
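+ // 8x8 tile: four accumulators. The packed 8x8 blocks interleave the row
+ // halves, so vec_A[x] and vec_A[x+1] are the rows 0-3 and rows 4-7
+ // halves of the same packed column.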
+ void KERNEL_8x8(int64_t ii, int64_t jj) {
+ vec_t vec_A[16], vec_B[16], vec_C[4];
+ acc_t acc_0, acc_1, acc_2, acc_3;
+ __builtin_mma_xxsetaccz(&acc_0);
+ __builtin_mma_xxsetaccz(&acc_1);
+ __builtin_mma_xxsetaccz(&acc_2);
+ __builtin_mma_xxsetaccz(&acc_3);
+ for (int64_t l = 0; l < k; l += 8) {
+ READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+ for(int x = 0; x < 16; x+=2) {
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
+ }
+ }
+ SAVE_ACC(&acc_0, ii, jj);
+ SAVE_ACC(&acc_1, ii, jj+4);
+ SAVE_ACC(&acc_2, ii+4, jj);
+ SAVE_ACC(&acc_3, ii+4, jj+4);
+ }
+
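+ // Dispatch: run the largest kernel that fits the remaining m x n area
+ // over the full tiles it covers, then recurse on the two remainder
+ // strips (rows from mp and columns from np) until the area is exhausted.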
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ int m_rem = MIN(m - m0, 16);
+ int n_rem = MIN(n - n0, 16);
+ if (m_rem >= 8 && n_rem >= 8) {
+ mc = 8;
+ nc = 8;
+ gemm<8,8>(m0, m, n0, n);
+ } else if (m_rem >= 4 && n_rem >= 8) {
+ mc = 4;
+ nc = 8;
+ gemm<4,8>(m0, m, n0, n);
+ } else if (m_rem >= 8 && n_rem >= 4) {
+ mc = 8;
+ nc = 4;
+ gemm<8,4>(m0, m, n0, n);
+ } else if (m_rem >= 4 && n_rem >= 4) {
+ mc = 4;
+ nc = 4;
+ gemm<4,4>(m0, m, n0, n);
+ } else if ((m_rem < 4) && (n_rem > 4)) {
+ nc = 4;
+ switch(m_rem) {
+ case 1:
+ mc = 1;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 2:
+ mc = 2;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 3:
+ mc = 3;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ default:
+ return;
+ }
+ } else if ((m_rem > 4) && (n_rem < 4)) {
+ mc = 4;
+ switch(n_rem) {
+ case 1:
+ nc = 1;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 2:
+ nc = 2;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 3:
+ nc = 3;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ default:
+ return;
+ }
+ } else {
+ switch((m_rem << 4) | n_rem) {
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x41:
+ mc = 4;
+ nc = 1;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x34:
+ mc = 3;
+ nc = 4;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x21:
+ mc = 2;
+ nc = 1;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x14:
+ mc = 1;
+ nc = 4;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x13:
+ mc = 1;
+ nc = 3;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm_small(m0, m, n0, n, mc, nc);
+ break;
+ default:
+ return;
+ }
+ }
+ mp = m0 + (m - m0) / mc * mc;
+ np = n0 + (n - n0) / nc * nc;
+ mnpack(mp, m, n0, np);
+ mnpack(m0, m, np, n);
+ }
+
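+ // Edge path for tiles where RM or RN is smaller than the fixed-size
+ // kernels: compute a full 4x4 accumulator but write back only its
+ // RM x RN corner. Tiles are split across threads via ith/nth, as in
+ // gemm() below.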
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
+ vec_t vec_C[4];
+ acc_t acc_0;
+ __builtin_mma_xxsetaccz(&acc_0);
+ vec_t vec_A[4], vec_B[4];
+ for (int64_t l = 0; l < k; l += 4) {
+ if (RN >= 4 && RM == 1) {
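+ // RM == 1: only row 0 of the accumulator is stored below, so load
+ // a[0..3] once and broadcast a[1], a[2], a[3] for the remaining
+ // k-steps; the upper lanes of vec_A[0] are don't-cares.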
+ float* a = const_cast<float*>(A+(ii)*lda+l);
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ vec_A[0] = (vec_t)vec_xl(0,a);
+ vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
+ vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
+ vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+ } else {
+ READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
+ READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+ }
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
+ }
+ __builtin_mma_disassemble_acc(vec_C, &acc_0);
+ for (int I = 0; I < RM; I++) {
+ for (int J = 0; J < RN; J++) {
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+ }
+ }
+ }
+ }
+
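+ // Tiled driver: bind the matching fixed-size kernel to the member
+ // function pointer, then round-robin the RM x RN tiles across threads.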
+ template <int RM, int RN>
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (RM == 4 && RN == 4) {
+ kernel = &tinyBLAS_PPC::KERNEL_4x4;
+ } else if (RM == 4 && RN == 8) {
+ kernel = &tinyBLAS_PPC::KERNEL_4x8;
+ } else if (RM == 8 && RN == 4) {
+ kernel = &tinyBLAS_PPC::KERNEL_8x4;
+ } else if (RM == 8 && RN == 8) {
+ kernel = &tinyBLAS_PPC::KERNEL_8x8;
+ }
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
+ (this->*kernel)(ii, jj);
+ }
+ }
+
+ const TA *const A;
+ const TB *const B;
+ TC *C;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
+ const int ith;
+ const int nth;
+};
+#endif // __MMA__
} // namespace
/**
ith, nth};
tb.matmul(m, n);
return true;
+#elif defined(__MMA__)
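+ // The MMA kernels consume k in steps of up to 8 (KERNEL_8x8), so bail
+ // out and let the caller fall back when k is not a multiple of 8.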
+ if (k % 8)
+ return false;
+ tinyBLAS_PPC<float, float, float> tb{
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
#else
return false;
#endif