} else if constexpr(RM == 8 && RN == 4) {
KERNEL_8x4(ii,jj);
} else {
- static_assert(false, "RN/RM values not supported");
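+ // Note: a non-dependent static_assert(false) is ill-formed even in a discarded
+ // `if constexpr` branch before C++23 (P2593), hence the runtime assert below.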
+ assert(false && "RN/RM values not supported");
}
}
const int nth;
};
-template <typename TA, typename TB, typename TC>
+template <typename TA>
class tinyBLAS_Q0_PPC {
public:
tinyBLAS_Q0_PPC(int64_t k,
const TA *A, int64_t lda,
- const TB *B, int64_t ldb,
- TC *C, int64_t ldc,
+ const block_q8_0 *B, int64_t ldb,
+ float *C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
private:
- template<int RM, int RN>
- inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
+ inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
}
}
-
- template<typename VA, typename VB, int size>
- void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
- int64_t i, j;
- TA *aoffset = NULL;
- VA *vecOffset = NULL;
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
- VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
- VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
- VB t1, t2, t3, t4, t5, t6, t7, t8;
+ /* This function processes quantized data from block_q4_0 elements.
+ * First we extract the two int4 values packed in each int8_t into two separate signed int8 vectors.
+ * Then we subtract 8 from each element, mapping the unsigned 4-bit range [0, 15] onto the signed range [-8, 7].
+ * Also compute the row sum, which is required later to compensate for this conversion. */
+ inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
const vector signed char lowMask = vec_splats((signed char)0xF);
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
const vector signed char v8 = vec_splats((signed char)0x8);
- aoffset = const_cast<TA*>(a);
- vecOffset = vec;
+ vector signed int vsum = {0};
+ vector signed int vsum2 = {0};
+ c[0] = vec_and(c[1], lowMask);
+ c[1] = vec_sr(c[1], v4);
+ c[0] = vec_sub(c[0], v8);
+ c[1] = vec_sub(c[1], v8);
+ vsum = vec_sum4s(c[0], vsum);
+ vsum2 = vec_sum4s(c[1], vsum2);
+ vsum = vec_add(vsum, vsum2);
+ *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+ }
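+ /* Illustrative scalar sketch (not part of the kernel) of what the vector code
+ * above does for one packed byte, assuming q4_0's standard nibble packing:
+ *
+ *     int8_t packed = 0x2B;                // low nibble 0xB, high nibble 0x2
+ *     int lo = (packed & 0xF) - 8;         // 11 - 8 =  3
+ *     int hi = ((packed >> 4) & 0xF) - 8;  //  2 - 8 = -6
+ *     rowsum += lo + hi;                   // accumulated into *ca
+ */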
+
+ template <typename V1, typename V2>
+ inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
- vector signed int vsum = {0};
- vector signed int vsum2 = {0};
+ V2 t1, t2, t3, t4, t5, t6, t7, t8;
+ const vector unsigned char xor_vector = vec_splats((unsigned char)0x80);
+ t1 = vec_perm(s1, s2, swiz1);
+ t2 = vec_perm(s1, s2, swiz2);
+ t3 = vec_perm(s3, s4, swiz1);
+ t4 = vec_perm(s3, s4, swiz2);
+ t5 = vec_perm(t1, t3, swiz3);
+ t6 = vec_perm(t1, t3, swiz4);
+ t7 = vec_perm(t2, t4, swiz3);
+ t8 = vec_perm(t2, t4, swiz4);
+ if (flip) {
+ t5 = vec_xor(t5, xor_vector);
+ t6 = vec_xor(t6, xor_vector);
+ t7 = vec_xor(t7, xor_vector);
+ t8 = vec_xor(t8, xor_vector);
+ }
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset+16);
+ vec_xst(t7, 0, vecOffset+32);
+ vec_xst(t8, 0, vecOffset+48);
+ }
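+ /* Sketch of the permute network above, viewed as 4-byte words: with inputs
+ * s1..s4 = [a0 a1 a2 a3] .. [d0 d1 d2 d3], the three vec_perm stages yield
+ *     t5 = [a0 b0 c0 d0], t6 = [a1 b1 c1 d1],
+ *     t7 = [a2 b2 c2 d2], t8 = [a3 b3 c3 d3],
+ * i.e. a 4x4 transpose of 32-bit lanes. The optional XOR with 0x80 flips each
+ * byte's sign bit (x ^ 0x80 == x + 128 mod 256), remapping signed int8 to the
+ * unsigned encoding the integer MMA kernels expect. */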
+ template<int size>
+ void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
+ int64_t i, j;
+ TA *aoffset = NULL;
+ int8_t *vecOffset = NULL;
+ TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+ TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+ vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+ vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+ aoffset = const_cast<TA*>(a);
+ vecOffset = vec;
j = (rows >> 3);
if (j > 0) {
do {
aoffset7 = aoffset6 + lda;
aoffset8 = aoffset7 + lda;
aoffset += 8 * lda;
-
i = (cols >> 2);
if (i > 0) {
do {
- c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
- c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
- c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
- c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
- c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
- c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
- c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
- c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
-
- c1[0] = vec_and(c1[1], lowMask);
- c1[1] = vec_sr(c1[1], v4);
- c1[0] = vec_sub(c1[0], v8);
- c1[1] = vec_sub(c1[1], v8);
- vsum = vec_sum4s(c1[0], vsum);
- vsum2 = vec_sum4s(c1[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c2[0] = vec_and(c2[1], lowMask);
- c2[1] = vec_sr(c2[1], v4);
- c2[0] = vec_sub(c2[0], v8);
- c2[1] = vec_sub(c2[1], v8);
- vsum = vec_sum4s(c2[0], vsum);
- vsum2 = vec_sum4s(c2[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c3[0] = vec_and(c3[1], lowMask);
- c3[1] = vec_sr(c3[1], v4);
- c3[0] = vec_sub(c3[0], v8);
- c3[1] = vec_sub(c3[1], v8);
- vsum = vec_sum4s(c3[0], vsum);
- vsum2 = vec_sum4s(c3[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c4[0] = vec_and(c4[1], lowMask);
- c4[1] = vec_sr(c4[1], v4);
- c4[0] = vec_sub(c4[0], v8);
- c4[1] = vec_sub(c4[1], v8);
- vsum = vec_sum4s(c4[0], vsum);
- vsum2 = vec_sum4s(c4[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c5[0] = vec_and(c5[1], lowMask);
- c5[1] = vec_sr(c5[1], v4);
- c5[0] = vec_sub(c5[0], v8);
- c5[1] = vec_sub(c5[1], v8);
- vsum = vec_sum4s(c5[0], vsum);
- vsum2 = vec_sum4s(c5[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c6[0] = vec_and(c6[1], lowMask);
- c6[1] = vec_sr(c6[1], v4);
- c6[0] = vec_sub(c6[0], v8);
- c6[1] = vec_sub(c6[1], v8);
- vsum = vec_sum4s(c6[0], vsum);
- vsum2 = vec_sum4s(c6[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c7[0] = vec_and(c7[1], lowMask);
- c7[1] = vec_sr(c7[1], v4);
- c7[0] = vec_sub(c7[0], v8);
- c7[1] = vec_sub(c7[1], v8);
- vsum = vec_sum4s(c7[0], vsum);
- vsum2 = vec_sum4s(c7[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c8[0] = vec_and(c8[1], lowMask);
- c8[1] = vec_sr(c8[1], v4);
- c8[0] = vec_sub(c8[0], v8);
- c8[1] = vec_sub(c8[1], v8);
- vsum = vec_sum4s(c8[0], vsum);
- vsum2 = vec_sum4s(c8[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
- t1 = vec_perm(c5[0], c6[0], swiz1);
- t2 = vec_perm(c5[0], c6[0], swiz2);
- t3 = vec_perm(c7[0], c8[0], swiz1);
- t4 = vec_perm(c7[0], c8[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+128);
- vec_xst(t6, 0, vecOffset+144);
- vec_xst(t7, 0, vecOffset+160);
- vec_xst(t8, 0, vecOffset+176);
-
- t1 = vec_perm(c5[1], c6[1], swiz1);
- t2 = vec_perm(c5[1], c6[1], swiz2);
- t3 = vec_perm(c7[1], c8[1], swiz1);
- t4 = vec_perm(c7[1], c8[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+192);
- vec_xst(t6, 0, vecOffset+208);
- vec_xst(t7, 0, vecOffset+224);
- vec_xst(t8, 0, vecOffset+240);
-
+ c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
+ c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+ c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+ c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
+ c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
+ c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
+ c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
+ c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
+
+ process_q4_elements(c1, &comparray[0]);
+ process_q4_elements(c2, &comparray[1]);
+ process_q4_elements(c3, &comparray[2]);
+ process_q4_elements(c4, &comparray[3]);
+ process_q4_elements(c5, &comparray[4]);
+ process_q4_elements(c6, &comparray[5]);
+ process_q4_elements(c7, &comparray[6]);
+ process_q4_elements(c8, &comparray[7]);
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+ vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
+ vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
aoffset1 += lda;
aoffset2 += lda;
aoffset3 += lda;
aoffset3 = aoffset2 + lda;
aoffset4 = aoffset3 + lda;
aoffset += 4 * lda;
-
i = (cols >> 2);
if (i > 0) {
do {
- c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
- c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
- c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
- c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
-
- c1[0] = vec_and(c1[1], lowMask);
- c1[1] = vec_sr(c1[1], v4);
- c1[0] = vec_sub(c1[0], v8);
- c1[1] = vec_sub(c1[1], v8);
- vsum = vec_sum4s(c1[0], vsum);
- vsum2 = vec_sum4s(c1[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c2[0] = vec_and(c2[1], lowMask);
- c2[1] = vec_sr(c2[1], v4);
- c2[0] = vec_sub(c2[0], v8);
- c2[1] = vec_sub(c2[1], v8);
- vsum = vec_sum4s(c2[0], vsum);
- vsum2 = vec_sum4s(c2[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c3[0] = vec_and(c3[1], lowMask);
- c3[1] = vec_sr(c3[1], v4);
- c3[0] = vec_sub(c3[0], v8);
- c3[1] = vec_sub(c3[1], v8);
- vsum = vec_sum4s(c3[0], vsum);
- vsum2 = vec_sum4s(c3[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c4[0] = vec_and(c4[1], lowMask);
- c4[1] = vec_sr(c4[1], v4);
- c4[0] = vec_sub(c4[0], v8);
- c4[1] = vec_sub(c4[1], v8);
- vsum = vec_sum4s(c4[0], vsum);
- vsum2 = vec_sum4s(c4[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats( 0);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
+ c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
+ c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+ c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+ c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
+
+ process_q4_elements(c1, &comparray[0]);
+ process_q4_elements(c2, &comparray[1]);
+ process_q4_elements(c3, &comparray[2]);
+ process_q4_elements(c4, &comparray[3]);
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
aoffset1 += lda;
aoffset2 += lda;
aoffset3 += lda;
if (i > 0) {
do {
switch(rows) {
- case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
- case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
- case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+ case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+ case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+ case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
break;
}
- c1[0] = vec_and(c1[1], lowMask);
- c1[1] = vec_sr(c1[1], v4);
- c1[0] = vec_sub(c1[0], v8);
- c1[1] = vec_sub(c1[1], v8);
- vsum = vec_sum4s(c1[0], vsum);
- vsum2 = vec_sum4s(c1[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c2[0] = vec_and(c2[1], lowMask);
- c2[1] = vec_sr(c2[1], v4);
- c2[0] = vec_sub(c2[0], v8);
- c2[1] = vec_sub(c2[1], v8);
- vsum = vec_sum4s(c2[0], vsum);
- vsum2 = vec_sum4s(c2[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c3[0] = vec_and(c3[1], lowMask);
- c3[1] = vec_sr(c3[1], v4);
- c3[0] = vec_sub(c3[0], v8);
- c3[1] = vec_sub(c3[1], v8);
- vsum = vec_sum4s(c3[0], vsum);
- vsum2 = vec_sum4s(c3[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- c4[0] = vec_and(c4[1], lowMask);
- c4[1] = vec_sr(c4[1], v4);
- c4[0] = vec_sub(c4[0], v8);
- c4[1] = vec_sub(c4[1], v8);
- vsum = vec_sum4s(c4[0], vsum);
- vsum2 = vec_sum4s(c4[1], vsum2);
- vsum = vec_add(vsum, vsum2);
- comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
- vsum = vec_splats(0);
- vsum2 = vec_splats(0);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
+ process_q4_elements(c1, &comparray[0]);
+ process_q4_elements(c2, &comparray[1]);
+ process_q4_elements(c3, &comparray[2]);
+ process_q4_elements(c4, &comparray[3]);
+ vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+ vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
aoffset1 += lda;
aoffset2 += lda;
aoffset3 += lda;
}
}
}
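+ /* Layout note: each vector_permute_store call above writes one byte-transposed
+ * 64-byte tile, so the eight-row path emits four tiles per column block (low and
+ * high nibbles for two groups of four rows) and the four-row path emits two. */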
-
template<typename VA, typename VB>
- void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+ void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
int64_t i, j;
- TB *aoffset = NULL;
+ block_q8_0 *aoffset = NULL;
VA *vecOffset = NULL;
- TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
- TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
- VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
- VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
- VB t1, t2, t3, t4, t5, t6, t7, t8;
- vector unsigned char xor_vector;
- uint8_t flip_vec = 0x80;
- xor_vector = vec_splats(flip_vec);
- vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
- vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
- vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
- vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
-
- aoffset = const_cast<TB*>(a);
+ block_q8_0* aoffsets[8];
+ __vector_pair arr[8];
+ VB c[8][2] = {0};
+ VB c1[8] = {0};
+ VB c2[8] = {0};
+ aoffset = const_cast<block_q8_0*>(a);
vecOffset = vec;
j = (rows >> 3);
if (j > 0) {
do {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
- aoffset5 = aoffset4 + lda;
- aoffset6 = aoffset5 + lda;
- aoffset7 = aoffset6 + lda;
- aoffset8 = aoffset7 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 8; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
aoffset += 8 * lda;
i = (cols >> 3);
if (i > 0) {
do {
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
-
- __builtin_vsx_disassemble_pair(c1, &C1);
- __builtin_vsx_disassemble_pair(c2, &C2);
- __builtin_vsx_disassemble_pair(c3, &C3);
- __builtin_vsx_disassemble_pair(c4, &C4);
- __builtin_vsx_disassemble_pair(c5, &C5);
- __builtin_vsx_disassemble_pair(c6, &C6);
- __builtin_vsx_disassemble_pair(c7, &C7);
- __builtin_vsx_disassemble_pair(c8, &C8);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
- t1 = vec_perm(c5[0], c6[0], swiz1);
- t2 = vec_perm(c5[0], c6[0], swiz2);
- t3 = vec_perm(c7[0], c8[0], swiz1);
- t4 = vec_perm(c7[0], c8[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset+128);
- vec_xst(t6, 0, vecOffset+144);
- vec_xst(t7, 0, vecOffset+160);
- vec_xst(t8, 0, vecOffset+176);
-
- t1 = vec_perm(c5[1], c6[1], swiz1);
- t2 = vec_perm(c5[1], c6[1], swiz2);
- t3 = vec_perm(c7[1], c8[1], swiz1);
- t4 = vec_perm(c7[1], c8[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ for (int it = 0; it < 8; it++) {
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+ c1[it] = c[it][0];
+ c2[it] = c[it][1];
}
- vec_xst(t5, 0, vecOffset+192);
- vec_xst(t6, 0, vecOffset+208);
- vec_xst(t7, 0, vecOffset+224);
- vec_xst(t8, 0, vecOffset+240);
-
- aoffset1 += lda;
- aoffset2 += lda;
- aoffset3 += lda;
- aoffset4 += lda;
- aoffset5 += lda;
- aoffset6 += lda;
- aoffset7 += lda;
- aoffset8 += lda;
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+ vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
+ vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
+ for (int it = 0; it < 8; it++)
+ aoffsets[it] += lda;
vecOffset += 256;
i--;
} while(i > 0);
}
if (rows & 4) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
- aoffset += 4 * lda;
-
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 4; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
+ aoffset += 4 * lda;
i = (cols >> 3);
if (i > 0) {
do {
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
-
- __builtin_vsx_disassemble_pair(c1, &C1);
- __builtin_vsx_disassemble_pair(c2, &C2);
- __builtin_vsx_disassemble_pair(c3, &C3);
- __builtin_vsx_disassemble_pair(c4, &C4);
-
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ for (int it = 0; it < 4; it++) {
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+ c1[it] = c[it][0];
+ c2[it] = c[it][1];
}
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+ for (int it = 0; it < 4; it++) {
+ aoffsets[it] += lda;
}
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
- aoffset1 += lda;
- aoffset2 += lda;
- aoffset3 += lda;
- aoffset4 += lda;
vecOffset += 128;
i--;
} while(i > 0);
}
}
+
if (rows & 3) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 3; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
i = (cols >> 3);
if (i > 0) {
do {
switch(rows) {
- case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
- __builtin_vsx_disassemble_pair(c3, &C3);
- case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
- __builtin_vsx_disassemble_pair(c2, &C2);
- case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
- __builtin_vsx_disassemble_pair(c1, &C1);
+ case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
+ __builtin_vsx_disassemble_pair(c[2], &arr[2]);
+ c1[2] = c[2][0]; c2[2] = c[2][1];
+ case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
+ __builtin_vsx_disassemble_pair(c[1], &arr[1]);
+ c1[1] = c[1][0]; c2[1] = c[1][1];
+ case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
+ __builtin_vsx_disassemble_pair(c[0], &arr[0]);
+ c1[0] = c[0][0]; c2[0] = c[0][1];
break;
}
- t1 = vec_perm(c1[0], c2[0], swiz1);
- t2 = vec_perm(c1[0], c2[0], swiz2);
- t3 = vec_perm(c3[0], c4[0], swiz1);
- t4 = vec_perm(c3[0], c4[0], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset+16);
- vec_xst(t7, 0, vecOffset+32);
- vec_xst(t8, 0, vecOffset+48);
-
- t1 = vec_perm(c1[1], c2[1], swiz1);
- t2 = vec_perm(c1[1], c2[1], swiz2);
- t3 = vec_perm(c3[1], c4[1], swiz1);
- t4 = vec_perm(c3[1], c4[1], swiz2);
- t5 = vec_perm(t1, t3, swiz3);
- t6 = vec_perm(t1, t3, swiz4);
- t7 = vec_perm(t2, t4, swiz3);
- t8 = vec_perm(t2, t4, swiz4);
- if (flip == true) {
- t5 = vec_xor(t5, xor_vector);
- t6 = vec_xor(t6, xor_vector);
- t7 = vec_xor(t7, xor_vector);
- t8 = vec_xor(t8, xor_vector);
- }
- vec_xst(t5, 0, vecOffset+64);
- vec_xst(t6, 0, vecOffset+80);
- vec_xst(t7, 0, vecOffset+96);
- vec_xst(t8, 0, vecOffset+112);
-
- aoffset1 += lda;
- aoffset2 += lda;
- aoffset3 += lda;
+ vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+ vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+ for (int it = 0; it < 3; it++)
+ aoffsets[it] += lda;
vecOffset += 128;
i--;
} while(i > 0);
}
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- int m_rem = MIN(m - m0, 8);
- int n_rem = MIN(n - n0, 8);
- // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
- // issues. After resolving them, below code will be enabled.
- /*if (m_rem >= 16 && n_rem >= 8) {
- mc = 16;
- nc = 8;
- gemm<16,8>(m0, m, n0, n);
- } else if(m_rem >= 8 && n_rem >= 16) {
- mc = 8;
- nc = 16;
- gemm<8,16>(m0, m, n0, n);
- }*/
+ int m_rem = MIN(m - m0, 16);
+ int n_rem = MIN(n - n0, 16);
+
+ int mc = 0, nc = 0;
+
if (m_rem >= 8 && n_rem >= 8) {
- mc = 8;
- nc = 8;
- gemm<8,8>(m0, m, n0, n);
+ mc = 8;
+ nc = 8;
+ gemm<8, 8>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 8) {
mc = 4;
nc = 8;
- gemm<4,8>(m0, m, n0, n);
+ gemm<4, 8>(m0, m, n0, n);
} else if (m_rem >= 8 && n_rem >= 4) {
mc = 8;
nc = 4;
- gemm<8,4>(m0, m, n0, n);
+ gemm<8, 4>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 4) {
mc = 4;
nc = 4;
- gemm_small<4, 4>(m0, m, n0, n);
- } else if ((m_rem < 4) && (n_rem > 4)) {
- nc = 4;
- switch(m_rem) {
- case 1:
- mc = 1;
- gemm_small<1, 4>(m0, m, n0, n);
- break;
- case 2:
- mc = 2;
- gemm_small<2, 4>(m0, m, n0, n);
- break;
- case 3:
- mc = 3;
- gemm_small<3, 4>(m0, m, n0, n);
- break;
- default:
- return;
- }
- } else if ((m_rem > 4) && (n_rem < 4)) {
- mc = 4;
- switch(n_rem) {
- case 1:
- nc = 1;
- gemm_small<4, 1>(m0, m, n0, n);
- break;
- case 2:
- nc = 2;
- gemm_small<4, 2>(m0, m, n0, n);
- break;
- case 3:
- nc = 3;
- gemm_small<4, 3>(m0, m, n0, n);
- break;
- default:
- return;
- }
+ gemm_small(m0, m, n0, n, mc, nc);
} else {
- switch((m_rem << 4) | n_rem) {
- case 0x43:
- mc = 4;
- nc = 3;
- gemm_small<4, 3>(m0, m, n0, n);
- break;
- case 0x42:
- mc = 4;
- nc = 2;
- gemm_small<4, 2>(m0, m, n0, n);
- break;
- case 0x41:
- mc = 4;
- nc = 1;
- gemm_small<4, 1>(m0, m, n0, n);
- break;
- case 0x34:
- mc = 3;
- nc = 4;
- gemm_small<3, 4>(m0, m, n0, n);
- break;
- case 0x33:
- mc = 3;
- nc = 3;
- gemm_small<3, 3>(m0, m, n0, n);
- break;
- case 0x32:
- mc = 3;
- nc = 2;
- gemm_small<3, 2>(m0, m, n0, n);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm_small<3, 1>(m0, m, n0, n);
- break;
- case 0x24:
- mc = 2;
- nc = 4;
- gemm_small<2, 4>(m0, m, n0, n);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm_small<2, 3>(m0, m, n0, n);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm_small<2, 2>(m0, m, n0, n);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm_small<2, 1>(m0, m, n0, n);
- break;
- case 0x14:
- mc = 1;
- nc = 4;
- gemm_small<1, 4>(m0, m, n0, n);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm_small<1, 3>(m0, m, n0, n);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm_small<1, 2>(m0, m, n0, n);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm_small<1, 1>(m0, m, n0, n);
- break;
- default:
- return;
- }
+ mc = (m_rem >= 4) ? 4 : m_rem;
+ nc = (n_rem >= 4) ? 4 : n_rem;
+ if (mc == 0 || nc == 0)
+ return;
+ gemm_small(m0, m, n0, n, mc, nc);
}
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
+
+ int64_t mp = m0 + ((m - m0) / mc) * mc;
+ int64_t np = n0 + ((n - n0) / nc) * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
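+ /* Tiling sketch (hypothetical sizes): for m = n = 10 the first pass runs the
+ * 8x8 kernel over rows/cols 0..7, leaving mp = np = 8; the two recursive mnpack
+ * calls then cover the remaining row and column strips with smaller kernels or
+ * gemm_small until no remainder is left. */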
+
void KERNEL_4x8(int64_t ii, int64_t jj) {
vec_t vec_A[8], vec_B[16] = {0};
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
if (std::is_same_v<TA, block_q4_0>) {
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
} else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
for(int x = 0; x < 8; x++) {
compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
}
- save_res<4, 4>(ii, jj, 0, fin_res);
- save_res<4, 4>(ii, jj+4, 4, fin_res);
+ save_res(ii, jj, 0, fin_res);
+ save_res(ii, jj+4, 4, fin_res);
}
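+ /* Compensation sketch: packNormal flips B with XOR 0x80, i.e. b' = b + 128, so
+ * each integer accumulator holds dot(a, b) + 128 * sum(a). The compute() helper
+ * uses comparray (the per-row sums gathered during packing) to subtract that
+ * 128 * sum(a) term before the result is scaled into fin_res. */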
void KERNEL_8x4(int64_t ii, int64_t jj) {
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
if (std::is_same_v<TA, block_q4_0>) {
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
} else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
for(int x = 0; x < 8; x++) {
compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
}
- save_res<4, 4>(ii, jj, 0, fin_res);
- save_res<4, 4>(ii+4, jj, 4, fin_res);
+ save_res(ii, jj, 0, fin_res);
+ save_res(ii+4, jj, 4, fin_res);
}
void KERNEL_8x8(int64_t ii, int64_t jj) {
__builtin_mma_xxsetaccz(&acc_2);
__builtin_mma_xxsetaccz(&acc_3);
if (std::is_same_v<TA, block_q4_0>) {
- packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
} else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
for(int x = 0; x < 8; x++) {
compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
}
- save_res<4, 4>(ii, jj, 0, fin_res);
- save_res<4, 4>(ii+4, jj, 4, fin_res);
- save_res<4, 4>(ii, jj+4, 8, fin_res);
- save_res<4, 4>(ii+4, jj+4, 12, fin_res);
+ save_res(ii, jj, 0, fin_res);
+ save_res(ii+4, jj, 4, fin_res);
+ save_res(ii, jj+4, 8, fin_res);
+ save_res(ii+4, jj+4, 12, fin_res);
}
- template<int RM, int RN>
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
__builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
__builtin_mma_xxsetaccz(&acc_0);
if (isAblock_q4) {
- packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+ packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
} else {
- packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+ packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
for(int x = 0; x < 8; x+=4) {
fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
}
}
- save_res<RM, RN>(ii, jj, 0, fin_res);
+ save_res(ii, jj, 0, fin_res, RM, RN);
}
}
} else if constexpr(RM == 8 && RN == 8) {
KERNEL_8x8(ii,jj);
} else {
- static_assert(false, "RN/RM values not supported");
+ assert(false && "RN/RM values not supported");
}
}
}
const TA *const A;
- const TB *const B;
- TC *C;
- TA *At;
- TB *Bt;
+ const block_q8_0 *const B;
+ float *C;
const int64_t k;
const int64_t lda;
const int64_t ldb;
const int nth;
};
-template <typename TA, typename TB, typename TC>
class tinyBLAS_PPC {
public:
tinyBLAS_PPC(int64_t k,
- const TA *A, int64_t lda,
- const TB *B, int64_t ldb,
- TC *C, int64_t ldc,
+ const float *A, int64_t lda,
+ const float *B, int64_t ldb,
+ float *C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
- template<typename VA>
- void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
+ inline void vector_permute_store_4(vector float *src, float *vecOffset) {
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
+ t1 = vec_mergeh(src[0], src[1]);
+ t2 = vec_mergeh(src[2], src[3]);
+ t3 = vec_mergel(src[0], src[1]);
+ t4 = vec_mergel(src[2], src[3]);
+
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t1, t2, 3);
+ t7 = vec_xxpermdi(t3, t4, 0);
+ t8 = vec_xxpermdi(t3, t4, 3);
+
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset + 4);
+ vec_xst(t7, 0, vecOffset + 8);
+ vec_xst(t8, 0, vecOffset + 12);
+ }
+
+ inline void vector_permute_store_8(vector float *src, float *vecOffset) {
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
+ t1 = vec_mergeh(src[0], src[1]);
+ t2 = vec_mergeh(src[2], src[3]);
+ t3 = vec_mergeh(src[4], src[5]);
+ t4 = vec_mergeh(src[6], src[7]);
+
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset + 4);
+ vec_xst(t7, 0, vecOffset + 8);
+ vec_xst(t8, 0, vecOffset + 12);
+
+ t1 = vec_mergel(src[0], src[1]);
+ t2 = vec_mergel(src[2], src[3]);
+ t3 = vec_mergel(src[4], src[5]);
+ t4 = vec_mergel(src[6], src[7]);
+
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+
+ vec_xst(t5, 0, vecOffset + 16);
+ vec_xst(t6, 0, vecOffset + 20);
+ vec_xst(t7, 0, vecOffset + 24);
+ vec_xst(t8, 0, vecOffset + 28);
+ }
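+ /* Both helpers are 4x4 float transposes: vec_mergeh/vec_mergel interleave row
+ * pairs and vec_xxpermdi selects 64-bit halves, so rows {a0 a1 a2 a3} .. {d0 d1 d2 d3}
+ * come out as columns {a0 b0 c0 d0} .. {a3 b3 c3 d3}. The _8 variant transposes
+ * eight rows, emitting each eight-float output row as two adjacent four-float stores. */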
+
+ void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
int64_t i, j;
- TA *aoffset = NULL, *boffset = NULL;
- TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
- TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
- VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
- VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
- VA t1, t2, t3, t4, t5, t6, t7, t8;
- aoffset = const_cast<TA*>(a);
+ float * aoffsets[8];
+ float *aoffset = NULL, *boffset = NULL;
+ __vector_pair arr[8];
+ vector float c[8][2] = {0};
+ vector float c1[8] = {0};
+ vector float c2[8] = {0};
+ aoffset = const_cast<float*>(a);
boffset = vec;
j = (rows >> 3);
if (j > 0) {
do {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
- aoffset5 = aoffset4 + lda;
- aoffset6 = aoffset5 + lda;
- aoffset7 = aoffset6 + lda;
- aoffset8 = aoffset7 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 8; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
aoffset += 8 * lda;
i = (cols >> 3);
if (i > 0) {
do {
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
- C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
- C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
- C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
- C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
- __builtin_vsx_disassemble_pair(c1, &C1);
- __builtin_vsx_disassemble_pair(c2, &C2);
- __builtin_vsx_disassemble_pair(c3, &C3);
- __builtin_vsx_disassemble_pair(c4, &C4);
- __builtin_vsx_disassemble_pair(c5, &C5);
- __builtin_vsx_disassemble_pair(c6, &C6);
- __builtin_vsx_disassemble_pair(c7, &C7);
- __builtin_vsx_disassemble_pair(c8, &C8);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_mergeh(c5[0], c6[0]);
- t4 = vec_mergeh(c7[0], c8[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset);
- vec_xst(t6, 0, boffset+4);
- vec_xst(t7, 0, boffset+8);
- vec_xst(t8, 0, boffset+12);
-
- t1 = vec_mergel(c1[0], c2[0]);
- t2 = vec_mergel(c3[0], c4[0]);
- t3 = vec_mergel(c5[0], c6[0]);
- t4 = vec_mergel(c7[0], c8[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+16);
- vec_xst(t6, 0, boffset+20);
- vec_xst(t7, 0, boffset+24);
- vec_xst(t8, 0, boffset+28);
-
- t1 = vec_mergeh(c1[1], c2[1]);
- t2 = vec_mergeh(c3[1], c4[1]);
- t3 = vec_mergeh(c5[1], c6[1]);
- t4 = vec_mergeh(c7[1], c8[1]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+32);
- vec_xst(t6, 0, boffset+36);
- vec_xst(t7, 0, boffset+40);
- vec_xst(t8, 0, boffset+44);
-
- t1 = vec_mergel(c1[1], c2[1]);
- t2 = vec_mergel(c3[1], c4[1]);
- t3 = vec_mergel(c5[1], c6[1]);
- t4 = vec_mergel(c7[1], c8[1]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+48);
- vec_xst(t6, 0, boffset+52);
- vec_xst(t7, 0, boffset+56);
- vec_xst(t8, 0, boffset+60);
-
- aoffset1 += 8*lda;
- aoffset2 += 8*lda;
- aoffset3 += 8*lda;
- aoffset4 += 8*lda;
+ for (int it = 0; it < 8; it++) {
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+ c1[it] = c[it][0];
+ c2[it] = c[it][1];
+ }
+
+ vector_permute_store_8(c1, boffset);
+ vector_permute_store_8(c2, boffset+32);
+ for (int it = 0; it < 4; it++)
+ aoffsets[it] += 8*lda;
boffset += 64;
i--;
} while(i > 0);
}
if (cols & 4) {
- c1[0] = vec_xl(0, aoffset1);
- c2[0] = vec_xl(0, aoffset2);
- c3[0] = vec_xl(0, aoffset3);
- c4[0] = vec_xl(0, aoffset4);
- c5[0] = vec_xl(0, aoffset5);
- c6[0] = vec_xl(0, aoffset6);
- c7[0] = vec_xl(0, aoffset7);
- c8[0] = vec_xl(0, aoffset8);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_mergeh(c5[0], c6[0]);
- t4 = vec_mergeh(c7[0], c8[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset);
- vec_xst(t6, 0, boffset+4);
- vec_xst(t7, 0, boffset+8);
- vec_xst(t8, 0, boffset+12);
-
- t1 = vec_mergel(c1[0], c2[0]);
- t2 = vec_mergel(c3[0], c4[0]);
- t3 = vec_mergel(c5[0], c6[0]);
- t4 = vec_mergel(c7[0], c8[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+16);
- vec_xst(t6, 0, boffset+20);
- vec_xst(t7, 0, boffset+24);
- vec_xst(t8, 0, boffset+28);
+ for (int it = 0; it < 8; it++)
+ c1[it] = vec_xl(0, aoffsets[it]);
+ vector_permute_store_8(c1, boffset);
}
j--;
} while(j > 0);
}
if (rows & 4) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
- aoffset4 = aoffset3 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 4; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
aoffset += 4 * lda;
i = (cols >> 3);
if (i > 0) {
do {
- C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
- C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
- C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
- C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
- __builtin_vsx_disassemble_pair(c1, &C1);
- __builtin_vsx_disassemble_pair(c2, &C2);
- __builtin_vsx_disassemble_pair(c3, &C3);
- __builtin_vsx_disassemble_pair(c4, &C4);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_mergel(c1[0], c2[0]);
- t4 = vec_mergel(c3[0], c4[0]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t1, t2, 3);
- t7 = vec_xxpermdi(t3, t4, 0);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset);
- vec_xst(t6, 0, boffset+4);
- vec_xst(t7, 0, boffset+8);
- vec_xst(t8, 0, boffset+12);
-
- t1 = vec_mergeh(c1[1], c2[1]);
- t2 = vec_mergeh(c3[1], c4[1]);
- t3 = vec_mergel(c1[1], c2[1]);
- t4 = vec_mergel(c3[1], c4[1]);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t1, t2, 3);
- t7 = vec_xxpermdi(t3, t4, 0);
- t8 = vec_xxpermdi(t3, t4, 3);
- vec_xst(t5, 0, boffset+16);
- vec_xst(t6, 0, boffset+20);
- vec_xst(t7, 0, boffset+24);
- vec_xst(t8, 0, boffset+28);
-
- aoffset1 += 8*lda;
- aoffset2 += 8*lda;
- aoffset3 += 8*lda;
- aoffset4 += 8*lda;
+ for (int it = 0; it < 4; it++) {
+ arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
+ __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+ c1[it] = c[it][0];
+ c2[it] = c[it][1];
+ }
+ vector_permute_store_4(c1, boffset);
+ vector_permute_store_4(c2, boffset+16);
+ for (int it = 0; it < 4; it++)
+ aoffsets[it] += 8*lda;
boffset += 32;
i--;
} while(i > 0);
}
if (cols & 4) {
- c1[0] = vec_xl(0, aoffset1);
- c2[0] = vec_xl(0, aoffset2);
- c3[0] = vec_xl(0, aoffset3);
- c4[0] = vec_xl(0, aoffset4);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset);
- vec_xst(t4, 0, boffset+4);
-
- t1 = vec_mergel(c1[0], c2[0]);
- t2 = vec_mergel(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset+8);
- vec_xst(t4, 0, boffset+12);
+ for (int it = 0; it < 4; it++)
+ c1[it] = vec_xl(0, aoffsets[it]);
+ vector_permute_store_4(c1, boffset);
}
}
if (rows & 3) {
- aoffset1 = aoffset;
- aoffset2 = aoffset1 + lda;
- aoffset3 = aoffset2 + lda;
+ aoffsets[0] = aoffset;
+ for (int it = 1; it < 3; it++)
+ aoffsets[it] = aoffsets[it-1] + lda;
if (cols & 4) {
- c1[0] = vec_xl(0, aoffset1);
- c2[0] = vec_xl(0, aoffset2);
- c3[0] = vec_xl(0, aoffset3);
-
- t1 = vec_mergeh(c1[0], c2[0]);
- t2 = vec_mergeh(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset);
- vec_xst(t4, 0, boffset+4);
-
- t1 = vec_mergel(c1[0], c2[0]);
- t2 = vec_mergel(c3[0], c4[0]);
- t3 = vec_xxpermdi(t1, t2, 0);
- t4 = vec_xxpermdi(t1, t2, 3);
- vec_xst(t3, 0, boffset+8);
- vec_xst(t4, 0, boffset+12);
+ for (int it = 0; it < 3; it++)
+ c1[it] = vec_xl(0, aoffsets[it]);
+ vector_permute_store_4(c1, boffset);
}
}
}
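+ /* Note: the eight-row column loop advances only aoffsets[0..3], matching the
+ * code it replaces. The call sites in this file pass cols of 4 or 8, so that
+ * loop runs at most once and the stale pointers are never read. */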
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
for (int l = 0; l < k; l+=4) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+ packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+ packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
for (int64_t l = 0; l < k; l+=4) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
+ packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+ packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
for (int64_t l = 0; l < k; l+=4) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+ packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
+ packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
__builtin_mma_xxsetaccz(&acc_2);
__builtin_mma_xxsetaccz(&acc_3);
for (int l = 0; l < k; l+=8) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
+ packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+ packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
for(int x = 0; x < 16; x+=2) {
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
}
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
- int64_t mc, nc, mp, np;
- int m_rem = MIN(m - m0, 16);
- int n_rem = MIN(n - n0, 16);
- if (m_rem >= 16 && n_rem >= 8) {
- mc = 8;
- nc = 8;
- gemm<8,8>(m0, m, n0, n);
- } else if(m_rem >= 8 && n_rem >= 16) {
- mc = 8;
- nc = 8;
- gemm<8,8>(m0, m, n0, n);
- } else if (m_rem >= 8 && n_rem >= 8) {
- mc = 8;
- nc = 8;
- gemm<8,8>(m0, m, n0, n);
+ int m_rem = MIN(m - m0, 8);
+ int n_rem = MIN(n - n0, 8);
+ int mc = 0, nc = 0;
+ if (m_rem >= 8 && n_rem >= 8) {
+ mc = 8;
+ nc = 8;
+ gemm<8, 8>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 8) {
- mc = 4;
- nc = 8;
- gemm<4,8>(m0, m, n0, n);
+ mc = 4;
+ nc = 8;
+ gemm<4, 8>(m0, m, n0, n);
} else if (m_rem >= 8 && n_rem >= 4) {
- mc = 8;
- nc = 4;
- gemm<8,4>(m0, m, n0, n);
+ mc = 8;
+ nc = 4;
+ gemm<8, 4>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 4) {
- mc = 4;
- nc = 4;
- gemm<4,4>(m0, m, n0, n);
- } else if ((m_rem < 4) && (n_rem > 4)) {
- nc = 4;
- switch(m_rem) {
- case 1:
- mc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 2:
- mc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 3:
- mc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- default:
- return;
- }
- } else if ((m_rem > 4) && (n_rem < 4)) {
- mc = 4;
- switch(n_rem) {
- case 1:
- nc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 2:
- nc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 3:
- nc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- default:
- return;
- }
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
} else {
- switch((m_rem << 4) | n_rem) {
- case 0x43:
- mc = 4;
- nc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x42:
- mc = 4;
- nc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x41:
- mc = 4;
- nc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x34:
- mc = 3;
- nc = 4;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x33:
- mc = 3;
- nc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x32:
- mc = 3;
- nc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x31:
- mc = 3;
- nc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x24:
- mc = 2;
- nc = 4;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x23:
- mc = 2;
- nc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x22:
- mc = 2;
- nc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x21:
- mc = 2;
- nc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x14:
- mc = 1;
- nc = 4;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x13:
- mc = 1;
- nc = 3;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x12:
- mc = 1;
- nc = 2;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- case 0x11:
- mc = 1;
- nc = 1;
- gemm_small(m0, m, n0, n, mc, nc);
- break;
- default:
- return;
- }
+ mc = (m_rem >= 4) ? 4 : m_rem;
+ nc = (n_rem >= 4) ? 4 : n_rem;
+ if (mc == 0 || nc == 0)
+ return;
+ gemm_small(m0, m, n0, n, mc, nc);
}
- mp = m0 + (m - m0) / mc * mc;
- np = n0 + (n - n0) / nc * nc;
+ int64_t mp = m0 + ((m - m0) / mc) * mc;
+ int64_t np = n0 + ((n - n0) / nc) * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
- }
+ }
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
int64_t ytiles = (m - m0) / RM;
* matrix elements.
*/
if (RM == 1) {
- TA* a = const_cast<TA*>(A+(ii)*lda+l);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
+ float* a = const_cast<float*>(A+(ii)*lda+l);
+ packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
vec_A[0] = (vec_t)vec_xl(0,a);
- vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
- vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
- vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+ vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
+ vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
+ vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
} else if (RN == 1) {
- packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
- TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+ packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
+ float* b = const_cast<float*>(B+(jj)*ldb+l);
vec_B[0] = (vec_t)vec_xl(0,b);
- vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
- vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
- vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
+ vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
+ vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
+ vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
} else {
- packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
- packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
+ packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
+ packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
}
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
__builtin_mma_disassemble_acc(vec_C, &acc_0);
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
- *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
}
}
}
}
}
- const TA *const A;
- const TB *const B;
- TC *C;
- TA *At;
- TB *Bt;
+ const float *const A;
+ const float *const B;
+ float *C;
const int64_t k;
const int64_t lda;
const int64_t ldb;
#elif defined(__MMA__)
if (k % 8)
return false;
- tinyBLAS_PPC<float, float, float> tb{
+ tinyBLAS_PPC tb{
k, (const float *)A, lda,
(const float *)B, ldb,
(float *)C, ldc,
return false;
if (m < 8 && m != 4)
return false;
- tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
+ tinyBLAS_Q0_PPC<block_q8_0> tb{
k, (const block_q8_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
return false;
if (m < 8 && m != 4)
return false;
- tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
+ tinyBLAS_Q0_PPC<block_q4_0> tb{
k, (const block_q4_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,