]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : optimize llamafile cpu matrix multiplication for ppc64le (#10156)
authoramritahs-ibm <redacted>
Sat, 9 Nov 2024 07:17:50 +0000 (12:47 +0530)
committerGitHub <redacted>
Sat, 9 Nov 2024 07:17:50 +0000 (09:17 +0200)
This change upstreams llamafile's cpu matrix
multiplication kernels for ppc64le using MMA
builtins for FP32 datatype.

This change results in a consistent 90%
improvement in input processing time, and 20%
to 80% improvement in output processing time,
across various batch sizes.

The patch is tested with Meta-Lllama-3-8B,
Mistral-7B, Llama-2-7B-chat-hf models on a
IBM POWER10 machine.

Signed-off-by: Amrita H S <redacted>
ggml/src/CMakeLists.txt
ggml/src/llamafile/sgemm.cpp

index 6c5b816d2f5e7d18fb4c972a7fbe5f3ef70aa925..a05f8c505c49211d61509470ce4b86144438e9ef 100644 (file)
@@ -1265,8 +1265,13 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
     endif()
 elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
     message(STATUS "PowerPC detected")
-    if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
-        list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
+    execute_process(COMMAND bash -c "grep POWER10 /proc/cpuinfo | head -n 1"
+                   OUTPUT_VARIABLE POWER10_M)
+    string(FIND ${POWER10_M} "POWER10" substring_index)
+    if(${substring_index} GREATER_EQUAL 0)
+       list(APPEND ARCH_FLAGS -mcpu=power10)
+    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
+       list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
     else()
         list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
         #TODO: Add  targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
index 9eead3f61e090dd2f1c0e4840f94d0279165d68e..da4146ec4f6886773c74c0bb1d31ae4f72d0462f 100644 (file)
@@ -106,6 +106,10 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
 inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
+#if defined(__MMA__)
+typedef vector unsigned char vec_t;
+typedef __vector_quad acc_t;
+#endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED FUSED MULTIPLY ADD
 
@@ -1026,6 +1030,600 @@ class tinyBLAS_Q0_AVX {
 };
 #endif // __AVX__
 
+//PPC Implementation
+#if defined(__MMA__)
+
+#define SAVE_ACC(ACC, ii, jj) \
+   __builtin_mma_disassemble_acc(vec_C, ACC); \
+   for (int I = 0; I < 4; I++) { \
+      for (int J = 0; J < 4; J++) { \
+         *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
+      } \
+   } \
+
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_PPC {
+  public:
+    tinyBLAS_PPC(int64_t k,
+                const TA *A, int64_t lda,
+                const TB *B, int64_t ldb,
+                TC *C, int64_t ldc,
+                int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+       mnpack(0, m, 0, n);
+    }
+
+  private:
+
+    void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
+
+    void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
+        int64_t i, j;
+        float *aoffset = NULL, *boffset = NULL;
+        float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+
+        aoffset = const_cast<float*>(a);
+        boffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+                i = (cols >> 3);
+                if (i > 0) {
+                    __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+                    vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
+                    vector float t1, t2, t3, t4, t5, t6, t7, t8;
+                    do {
+                        C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
+                        C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
+                        C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
+                        C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
+                        C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
+                        C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
+                        C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
+                        C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
+                        __builtin_vsx_disassemble_pair(c1, &C1);
+                        __builtin_vsx_disassemble_pair(c2, &C2);
+                        __builtin_vsx_disassemble_pair(c3, &C3);
+                        __builtin_vsx_disassemble_pair(c4, &C4);
+                        __builtin_vsx_disassemble_pair(c5, &C5);
+                        __builtin_vsx_disassemble_pair(c6, &C6);
+                        __builtin_vsx_disassemble_pair(c7, &C7);
+                        __builtin_vsx_disassemble_pair(c8, &C8);
+
+                        t1 = vec_mergeh(c1[0], c2[0]);
+                        t2 = vec_mergeh(c3[0], c4[0]);
+                        t3 = vec_mergeh(c5[0], c6[0]);
+                        t4 = vec_mergeh(c7[0], c8[0]);
+                        t5 = vec_xxpermdi(t1, t2, 0);
+                        t6 = vec_xxpermdi(t3, t4, 0);
+                        t7 = vec_xxpermdi(t1, t2, 3);
+                        t8 = vec_xxpermdi(t3, t4, 3);
+                        vec_xst(t5, 0, boffset);
+                        vec_xst(t6, 0, boffset+4);
+                        vec_xst(t7, 0, boffset+8);
+                        vec_xst(t8, 0, boffset+12);
+
+                        t1 = vec_mergel(c1[0], c2[0]);
+                        t2 = vec_mergel(c3[0], c4[0]);
+                        t3 = vec_mergel(c5[0], c6[0]);
+                        t4 = vec_mergel(c7[0], c8[0]);
+                        t5 = vec_xxpermdi(t1, t2, 0);
+                        t6 = vec_xxpermdi(t3, t4, 0);
+                        t7 = vec_xxpermdi(t1, t2, 3);
+                        t8 = vec_xxpermdi(t3, t4, 3);
+                        vec_xst(t5, 0, boffset+16);
+                        vec_xst(t6, 0, boffset+20);
+                        vec_xst(t7, 0, boffset+24);
+                        vec_xst(t8, 0, boffset+28);
+
+                        t1 = vec_mergeh(c1[1], c2[1]);
+                        t2 = vec_mergeh(c3[1], c4[1]);
+                        t3 = vec_mergeh(c5[1], c6[1]);
+                        t4 = vec_mergeh(c7[1], c8[1]);
+                        t5 = vec_xxpermdi(t1, t2, 0);
+                        t6 = vec_xxpermdi(t3, t4, 0);
+                        t7 = vec_xxpermdi(t1, t2, 3);
+                        t8 = vec_xxpermdi(t3, t4, 3);
+                        vec_xst(t5, 0, boffset+32);
+                        vec_xst(t6, 0, boffset+36);
+                        vec_xst(t7, 0, boffset+40);
+                        vec_xst(t8, 0, boffset+44);
+
+                        t1 = vec_mergel(c1[1], c2[1]);
+                        t2 = vec_mergel(c3[1], c4[1]);
+                        t3 = vec_mergel(c5[1], c6[1]);
+                        t4 = vec_mergel(c7[1], c8[1]);
+                        t5 = vec_xxpermdi(t1, t2, 0);
+                        t6 = vec_xxpermdi(t3, t4, 0);
+                        t7 = vec_xxpermdi(t1, t2, 3);
+                        t8 = vec_xxpermdi(t3, t4, 3);
+                        vec_xst(t5, 0, boffset+48);
+                        vec_xst(t6, 0, boffset+52);
+                        vec_xst(t7, 0, boffset+56);
+                        vec_xst(t8, 0, boffset+60);
+
+                        aoffset1 += 8*lda;
+                        aoffset2 += 8*lda;
+                        aoffset3 += 8*lda;
+                        aoffset4 += 8*lda;
+                        boffset += 64;
+                        i--;
+                    } while(i > 0);
+                }
+                if (cols & 4) {
+                    vector float c1, c2, c3, c4, c5, c6, c7, c8;
+                    vector float t1, t2, t3, t4, t5, t6, t7, t8;
+                    c1 = vec_xl(0, aoffset1);
+                    c2 = vec_xl(0, aoffset2);
+                    c3 = vec_xl(0, aoffset3);
+                    c4 = vec_xl(0, aoffset4);
+                    c5 = vec_xl(0, aoffset5);
+                    c6 = vec_xl(0, aoffset6);
+                    c7 = vec_xl(0, aoffset7);
+                    c8 = vec_xl(0, aoffset8);
+
+                    t1 = vec_mergeh(c1, c2);
+                    t2 = vec_mergeh(c3, c4);
+                    t3 = vec_mergeh(c5, c6);
+                    t4 = vec_mergeh(c7, c8);
+                    t5 = vec_xxpermdi(t1, t2, 0);
+                    t6 = vec_xxpermdi(t3, t4, 0);
+                    t7 = vec_xxpermdi(t1, t2, 3);
+                    t8 = vec_xxpermdi(t3, t4, 3);
+                    vec_xst(t5, 0, boffset);
+                    vec_xst(t6, 0, boffset+4);
+                    vec_xst(t7, 0, boffset+8);
+                    vec_xst(t8, 0, boffset+12);
+
+                    t1 = vec_mergel(c1, c2);
+                    t2 = vec_mergel(c3, c4);
+                    t3 = vec_mergel(c5, c6);
+                    t4 = vec_mergel(c7, c8);
+                    t5 = vec_xxpermdi(t1, t2, 0);
+                    t6 = vec_xxpermdi(t3, t4, 0);
+                    t7 = vec_xxpermdi(t1, t2, 3);
+                    t8 = vec_xxpermdi(t3, t4, 3);
+                    vec_xst(t5, 0, boffset+16);
+                    vec_xst(t6, 0, boffset+20);
+                    vec_xst(t7, 0, boffset+24);
+                    vec_xst(t8, 0, boffset+28);
+                }
+            j--;
+            } while(j > 0);
+        }
+
+        if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+            i = (cols >> 3);
+            if (i > 0) {
+                __vector_pair C1, C2, C3, C4;
+                vector float c1[2], c2[2], c3[2], c4[2];
+                vector float t1, t2, t3, t4, t5, t6, t7, t8;
+                do {
+                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
+                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
+                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
+                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
+                    __builtin_vsx_disassemble_pair(c1, &C1);
+                    __builtin_vsx_disassemble_pair(c2, &C2);
+                    __builtin_vsx_disassemble_pair(c3, &C3);
+                    __builtin_vsx_disassemble_pair(c4, &C4);
+
+                    t1 = vec_mergeh(c1[0], c2[0]);
+                    t2 = vec_mergeh(c3[0], c4[0]);
+                    t3 = vec_mergel(c1[0], c2[0]);
+                    t4 = vec_mergel(c3[0], c4[0]);
+                    t5 = vec_xxpermdi(t1, t2, 0);
+                    t6 = vec_xxpermdi(t1, t2, 3);
+                    t7 = vec_xxpermdi(t3, t4, 0);
+                    t8 = vec_xxpermdi(t3, t4, 3);
+                    vec_xst(t5, 0, boffset);
+                    vec_xst(t6, 0, boffset+4);
+                    vec_xst(t7, 0, boffset+8);
+                    vec_xst(t8, 0, boffset+12);
+
+                    t1 = vec_mergeh(c1[1], c2[1]);
+                    t2 = vec_mergeh(c3[1], c4[1]);
+                    t3 = vec_mergel(c1[1], c2[1]);
+                    t4 = vec_mergel(c3[1], c4[1]);
+                    t5 = vec_xxpermdi(t1, t2, 0);
+                    t6 = vec_xxpermdi(t1, t2, 3);
+                    t7 = vec_xxpermdi(t3, t4, 0);
+                    t8 = vec_xxpermdi(t3, t4, 3);
+                    vec_xst(t5, 0, boffset+16);
+                    vec_xst(t6, 0, boffset+20);
+                    vec_xst(t7, 0, boffset+24);
+                    vec_xst(t8, 0, boffset+28);
+
+                    aoffset1 += 8*lda;
+                    aoffset2 += 8*lda;
+                    aoffset3 += 8*lda;
+                    aoffset4 += 8*lda;
+                    boffset += 32;
+                    i--;
+                } while(i > 0);
+            }
+
+            if (cols & 4) {
+                vector float c1, c2, c3, c4;
+                vector float t1, t2, t3, t4;
+                c1 = vec_xl(0, aoffset1);
+                c2 = vec_xl(0, aoffset2);
+                c3 = vec_xl(0, aoffset3);
+                c4 = vec_xl(0, aoffset4);
+
+                t1 = vec_mergeh(c1, c2);
+                t2 = vec_mergeh(c3, c4);
+                t3 = vec_xxpermdi(t1, t2, 0);
+                t4 = vec_xxpermdi(t1, t2, 3);
+                vec_xst(t3, 0, boffset);
+                vec_xst(t4, 0, boffset+4);
+
+                t1 = vec_mergel(c1, c2);
+                t2 = vec_mergel(c3, c4);
+                t3 = vec_xxpermdi(t1, t2, 0);
+                t4 = vec_xxpermdi(t1, t2, 3);
+                vec_xst(t3, 0, boffset+8);
+                vec_xst(t4, 0, boffset+12);
+            }
+        }
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            if (cols & 4) {
+                vector float c1, c2, c3, c4 = {0};
+                vector float t1, t2, t3, t4;
+                c1 = vec_xl(0, aoffset1);
+                c2 = vec_xl(0, aoffset2);
+                c3 = vec_xl(0, aoffset3);
+
+                t1 = vec_mergeh(c1, c2);
+                t2 = vec_mergeh(c3, c4);
+                t3 = vec_xxpermdi(t1, t2, 0);
+                t4 = vec_xxpermdi(t1, t2, 3);
+                vec_xst(t3, 0, boffset);
+                vec_xst(t4, 0, boffset+4);
+
+                t1 = vec_mergel(c1, c2);
+                t2 = vec_mergel(c3, c4);
+                t3 = vec_xxpermdi(t1, t2, 0);
+                t4 = vec_xxpermdi(t1, t2, 3);
+                vec_xst(t3, 0, boffset+8);
+                vec_xst(t4, 0, boffset+12);
+            }
+        }
+    }
+
+    void KERNEL_4x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[4], vec_B[4], vec_C[4];
+        acc_t acc_0;
+        __builtin_mma_xxsetaccz(&acc_0);
+        for (int l = 0; l < k; l+=4) {
+            READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+            READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+    }
+
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[4], vec_B[8], vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int64_t l = 0; l < k; l+=4) {
+            READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+            READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
+            __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
+            __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
+            __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
+            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
+            __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[4], vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int64_t l = 0; l < k; l+=4) {
+            READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
+            READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
+            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
+            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
+            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
+            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
+            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
+            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
+            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii+4, jj);
+    }
+
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[16], vec_B[16], vec_C[4];
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        __builtin_mma_xxsetaccz(&acc_2);
+        __builtin_mma_xxsetaccz(&acc_3);
+        for (int l = 0; l < k; l+=8) {
+            READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+            READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+            for(int x = 0; x < 16; x+=2) {
+                __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
+                __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
+                __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
+                __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+        SAVE_ACC(&acc_2, ii+4, jj);
+        SAVE_ACC(&acc_3, ii+4, jj+4);
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 16);
+        int n_rem = MIN(n - n0, 16);
+        if (m_rem >= 16 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if(m_rem >= 8 && n_rem >= 16) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 4) {
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm<4,4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem > 4)) {
+            nc = 4;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                default:
+                    return;
+            }
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+     void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0;
+            __builtin_mma_xxsetaccz(&acc_0);
+            vec_t vec_A[4], vec_B[4];
+            for (int l=0; l<k; l+=4) {
+                if (RN >= 4 && RM == 1) {
+                    float* a = const_cast<float*>(A+(ii)*lda+l);
+                    READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                    vec_A[0] = (vec_t)vec_xl(0,a);
+                    vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
+                    vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
+                    vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+                } else {
+                    READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
+                    READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+                }
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < RN; J++) {
+                    *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+                }
+            }
+       }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (RM == 4 && RN == 4) {
+            kernel = &tinyBLAS_PPC::KERNEL_4x4;
+        } else if (RM == 4 && RN == 8) {
+            kernel = &tinyBLAS_PPC::KERNEL_4x8;
+        } else if (RM == 8 && RN == 4) {
+            kernel = &tinyBLAS_PPC::KERNEL_8x4;
+        } else if (RM == 8 && RN == 8) {
+            kernel = &tinyBLAS_PPC::KERNEL_8x8;
+        }
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            (this->*kernel)(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    TA *At;
+    TB *Bt;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+#endif
 } // namespace
 
 /**
@@ -1114,6 +1712,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             ith, nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__MMA__)
+        if (k % 8)
+            return false;
+        tinyBLAS_PPC<float, float, float> tb{
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
 #else
         return false;
 #endif