} \
} \
+// Compile-time selector for the MMA outer-product builtin, keyed on the
+// 16-bit element type. The primary template is only declared, so any element
+// type without a specialization below fails to compile when it is used.
+template <typename T>
+struct mma_instr;
+
+// ggml_bf16_t operands: forwards to __builtin_mma_xvbf16ger2pp.
+template <>
+struct mma_instr<ggml_bf16_t> {
+    static inline void outer_product(acc_t *accumulator, vec_t lhs, vec_t rhs) {
+        __builtin_mma_xvbf16ger2pp(accumulator, lhs, rhs);
+    }
+};
+
+// ggml_fp16_t operands: forwards to __builtin_mma_xvf16ger2pp.
+template <>
+struct mma_instr<ggml_fp16_t> {
+    static inline void outer_product(acc_t *accumulator, vec_t lhs, vec_t rhs) {
+        __builtin_mma_xvf16ger2pp(accumulator, lhs, rhs);
+    }
+};
+
template <typename TA, typename TB, typename TC>
-class tinyBLAS_BF16_PPC {
+class tinyBLAS_HP16_PPC {
public:
- tinyBLAS_BF16_PPC(int64_t k,
+ tinyBLAS_HP16_PPC(int64_t k,
const TA *A, int64_t lda,
const TB *B, int64_t ldb,
TC *C, int64_t ldc,
packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
for (int x = 0; x < 4; x++) {
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
- __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
}
}
SAVE_ACC(&acc_0, ii, jj);
packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
for (int x = 0; x < 4; x++) {
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
- __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
+ // Same operands as the builtin calls being replaced: acc_0 uses vec_A[x],
+ // acc_1 uses the second group of four A vectors (vec_A[x+4]); both share
+ // vec_B[x]. Only the instruction is now chosen through mma_instr<TA>.
+ // NOTE(review): vec_B[x+4] would index past the four B vectors this
+ // kernel appears to use - confirm against the vec_B declaration.
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
}
}
SAVE_ACC(&acc_0, ii, jj);
packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
for (int x = 0; x < 4; x++) {
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
- __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
- __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
- __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
+ mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
+ mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
}
}
packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
for (int x = 0; x<2; x++) {
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
}
}
__builtin_mma_disassemble_acc(vec_C, &acc_0);
packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
for (int x = 0; x<4; x++) {
- __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
- __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+ mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+ mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
}
}
__builtin_mma_disassemble_acc(vec_C, &acc_0);
return tb.matmul(m, n);
}
#elif defined(__MMA__)
- if ((k % 8))
- return false;
- if(Btype == GGML_TYPE_BF16) {
- tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
- (const ggml_bf16_t *)A, lda,
- (const ggml_bf16_t *)B, ldb,
- (float *)C, ldc,
- params->ith, params->nth};
- tb.matmul(m, n);
- return true;
+ if (k % 8) {
+ return false;
+ }
+
+ if (Btype == GGML_TYPE_BF16) {
+ tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+ (const ggml_bf16_t *)A, lda,
+ (const ggml_bf16_t *)B, ldb,
+ (float *)C, ldc,
+ params->ith, params->nth };
+
+ tb.matmul(m, n);
+ return true;
}
#elif defined(__riscv_zvfbfwma)
#if LMUL == 1
#endif
return tb.matmul(m, n);
}
+#elif defined(__MMA__)
+ if (k % 8) {
+ return false;
+ }
+
+ if (Btype == GGML_TYPE_F16) {
+ tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
+ (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc,
+ params->ith, params->nth };
+
+ tb.matmul(m, n);
+ return true;
+ }
#endif
return false;
}