]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
llamafile : ppc64le MMA implementation for Q4_0. (llama/12489)
authoramritahs-ibm <redacted>
Thu, 27 Mar 2025 06:51:47 +0000 (12:21 +0530)
committerGeorgi Gerganov <redacted>
Thu, 27 Mar 2025 09:06:03 +0000 (11:06 +0200)
This change upstreams llamafile's cpu matrix
multiplication kernels for ppc64le ISA using MMA
builtins. This patch handles matrix multiplication
between quantised datatypes, block_q4_0 and
block_q8_0.

This change results in 5% - 50% improvement
in total speed(ie all tokens/total time), across
various batch sizes.

The patch is tested with Meta-Lllama-3-8B,
Mistral-7B, Llama-2-7B-chat-hf models on a
IBM POWER10 machine.

Signed-off-by: Amrita H S <redacted>
ggml/src/ggml-cpu/llamafile/sgemm.cpp

index e0482c59377fd22ca628bf41d3e06856f04ee801..92dfbc2d2c9f0702bc3252da4c8a18fc4dfdb3ad 100644 (file)
@@ -55,6 +55,7 @@
 
 #include <atomic>
 #include <array>
+#include <type_traits>
 
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
@@ -1092,13 +1093,403 @@ class tinyBLAS_Q0_PPC {
        }
     }
 
-    template<typename VA, typename VB>
-    void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+    template<typename VA, typename VB, int size>
+    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
         int64_t i, j;
         TA *aoffset = NULL;
         VA *vecOffset = NULL;
         TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
         TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        VB t1, t2, t3, t4, t5, t6, t7, t8;
+        const vector signed char lowMask = vec_splats((signed char)0xF);
+        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+        const vector signed char v8 = vec_splats((signed char)0x8);
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+        vector signed int vsum = {0};
+        vector signed int vsum2 = {0};
+
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+
+                i = (cols >> 2);
+                if (i > 0) {
+                    do {
+                        c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                        c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                        c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                        c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+                        c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
+                        c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
+                        c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
+                        c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
+
+                        c1[0] = vec_and(c1[1], lowMask);
+                        c1[1] = vec_sr(c1[1], v4);
+                        c1[0] = vec_sub(c1[0], v8);
+                        c1[1] = vec_sub(c1[1], v8);
+                        vsum = vec_sum4s(c1[0], vsum);
+                        vsum2 = vec_sum4s(c1[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c2[0] = vec_and(c2[1], lowMask);
+                        c2[1] = vec_sr(c2[1], v4);
+                        c2[0] = vec_sub(c2[0], v8);
+                        c2[1] = vec_sub(c2[1], v8);
+                        vsum = vec_sum4s(c2[0], vsum);
+                        vsum2 = vec_sum4s(c2[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c3[0] = vec_and(c3[1], lowMask);
+                        c3[1] = vec_sr(c3[1], v4);
+                        c3[0] = vec_sub(c3[0], v8);
+                        c3[1] = vec_sub(c3[1], v8);
+                        vsum = vec_sum4s(c3[0], vsum);
+                        vsum2 = vec_sum4s(c3[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c4[0] = vec_and(c4[1], lowMask);
+                        c4[1] = vec_sr(c4[1], v4);
+                        c4[0] = vec_sub(c4[0], v8);
+                        c4[1] = vec_sub(c4[1], v8);
+                        vsum = vec_sum4s(c4[0], vsum);
+                        vsum2 = vec_sum4s(c4[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c5[0] = vec_and(c5[1], lowMask);
+                        c5[1] = vec_sr(c5[1], v4);
+                        c5[0] = vec_sub(c5[0], v8);
+                        c5[1] = vec_sub(c5[1], v8);
+                        vsum = vec_sum4s(c5[0], vsum);
+                        vsum2 = vec_sum4s(c5[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c6[0] = vec_and(c6[1], lowMask);
+                        c6[1] = vec_sr(c6[1], v4);
+                        c6[0] = vec_sub(c6[0], v8);
+                        c6[1] = vec_sub(c6[1], v8);
+                        vsum = vec_sum4s(c6[0], vsum);
+                        vsum2 = vec_sum4s(c6[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c7[0] = vec_and(c7[1], lowMask);
+                        c7[1] = vec_sr(c7[1], v4);
+                        c7[0] = vec_sub(c7[0], v8);
+                        c7[1] = vec_sub(c7[1], v8);
+                        vsum = vec_sum4s(c7[0], vsum);
+                        vsum2 = vec_sum4s(c7[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c8[0] = vec_and(c8[1], lowMask);
+                        c8[1] = vec_sr(c8[1], v4);
+                        c8[0] = vec_sub(c8[0], v8);
+                        c8[1] = vec_sub(c8[1], v8);
+                        vsum = vec_sum4s(c8[0], vsum);
+                        vsum2 = vec_sum4s(c8[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        t1 = vec_perm(c1[0], c2[0], swiz1);
+                        t2 = vec_perm(c1[0], c2[0], swiz2);
+                        t3 = vec_perm(c3[0], c4[0], swiz1);
+                        t4 = vec_perm(c3[0], c4[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset);
+                        vec_xst(t6, 0, vecOffset+16);
+                        vec_xst(t7, 0, vecOffset+32);
+                        vec_xst(t8, 0, vecOffset+48);
+
+                        t1 = vec_perm(c1[1], c2[1], swiz1);
+                        t2 = vec_perm(c1[1], c2[1], swiz2);
+                        t3 = vec_perm(c3[1], c4[1], swiz1);
+                        t4 = vec_perm(c3[1], c4[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+64);
+                        vec_xst(t6, 0, vecOffset+80);
+                        vec_xst(t7, 0, vecOffset+96);
+                        vec_xst(t8, 0, vecOffset+112);
+
+                        t1 = vec_perm(c5[0], c6[0], swiz1);
+                        t2 = vec_perm(c5[0], c6[0], swiz2);
+                        t3 = vec_perm(c7[0], c8[0], swiz1);
+                        t4 = vec_perm(c7[0], c8[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+128);
+                        vec_xst(t6, 0, vecOffset+144);
+                        vec_xst(t7, 0, vecOffset+160);
+                        vec_xst(t8, 0, vecOffset+176);
+
+                        t1 = vec_perm(c5[1], c6[1], swiz1);
+                        t2 = vec_perm(c5[1], c6[1], swiz2);
+                        t3 = vec_perm(c7[1], c8[1], swiz1);
+                        t4 = vec_perm(c7[1], c8[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+192);
+                        vec_xst(t6, 0, vecOffset+208);
+                        vec_xst(t7, 0, vecOffset+224);
+                        vec_xst(t8, 0, vecOffset+240);
+
+                        aoffset1 += lda;
+                        aoffset2 += lda;
+                        aoffset3 += lda;
+                        aoffset4 += lda;
+                        aoffset5 += lda;
+                        aoffset6 += lda;
+                        aoffset7 += lda;
+                        aoffset8 += lda;
+                        vecOffset += 256;
+                        i--;
+                    } while (i > 0);
+                }
+                j--;
+            } while (j > 0);
+        }
+
+        if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+
+            i = (cols >> 2);
+            if (i > 0) {
+                do {
+                    c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                    c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                    c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                    c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+
+                    c1[0] = vec_and(c1[1], lowMask);
+                    c1[1] = vec_sr(c1[1], v4);
+                    c1[0] = vec_sub(c1[0], v8);
+                    c1[1] = vec_sub(c1[1], v8);
+                    vsum = vec_sum4s(c1[0], vsum);
+                    vsum2 = vec_sum4s(c1[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c2[0] = vec_and(c2[1], lowMask);
+                    c2[1] = vec_sr(c2[1], v4);
+                    c2[0] = vec_sub(c2[0], v8);
+                    c2[1] = vec_sub(c2[1], v8);
+                    vsum = vec_sum4s(c2[0], vsum);
+                    vsum2 = vec_sum4s(c2[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c3[0] = vec_and(c3[1], lowMask);
+                    c3[1] = vec_sr(c3[1], v4);
+                    c3[0] = vec_sub(c3[0], v8);
+                    c3[1] = vec_sub(c3[1], v8);
+                    vsum = vec_sum4s(c3[0], vsum);
+                    vsum2 = vec_sum4s(c3[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c4[0] = vec_and(c4[1], lowMask);
+                    c4[1] = vec_sr(c4[1], v4);
+                    c4[0] = vec_sub(c4[0], v8);
+                    c4[1] = vec_sub(c4[1], v8);
+                    vsum = vec_sum4s(c4[0], vsum);
+                    vsum2 = vec_sum4s(c4[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats( 0);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while (i > 0);
+            }
+        }
+
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            i = (cols >> 2);
+            if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                        case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                        case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                            break;
+                    }
+                    c1[0] = vec_and(c1[1], lowMask);
+                    c1[1] = vec_sr(c1[1], v4);
+                    c1[0] = vec_sub(c1[0], v8);
+                    c1[1] = vec_sub(c1[1], v8);
+                    vsum = vec_sum4s(c1[0], vsum);
+                    vsum2 = vec_sum4s(c1[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c2[0] = vec_and(c2[1], lowMask);
+                    c2[1] = vec_sr(c2[1], v4);
+                    c2[0] = vec_sub(c2[0], v8);
+                    c2[1] = vec_sub(c2[1], v8);
+                    vsum = vec_sum4s(c2[0], vsum);
+                    vsum2 = vec_sum4s(c2[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c3[0] = vec_and(c3[1], lowMask);
+                    c3[1] = vec_sr(c3[1], v4);
+                    c3[0] = vec_sub(c3[0], v8);
+                    c3[1] = vec_sub(c3[1], v8);
+                    vsum = vec_sum4s(c3[0], vsum);
+                    vsum2 = vec_sum4s(c3[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c4[0] = vec_and(c4[1], lowMask);
+                    c4[1] = vec_sr(c4[1], v4);
+                    c4[0] = vec_sub(c4[0], v8);
+                    c4[1] = vec_sub(c4[1], v8);
+                    vsum = vec_sum4s(c4[0], vsum);
+                    vsum2 = vec_sum4s(c4[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while(i > 0);
+            }
+        }
+    }
+
+    template<typename VA, typename VB>
+    void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+        int64_t i, j;
+        TB *aoffset = NULL;
+        VA *vecOffset = NULL;
+        TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
         __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
         VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
         VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
@@ -1111,24 +1502,24 @@ class tinyBLAS_Q0_PPC {
         vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
         vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
 
-        aoffset = const_cast<TA*>(a);
+        aoffset = const_cast<TB*>(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
             do {
-            aoffset1 = aoffset;
-            aoffset2 = aoffset1 + lda;
-            aoffset3 = aoffset2 + lda;
-            aoffset4 = aoffset3 + lda;
-            aoffset5 = aoffset4 + lda;
-            aoffset6 = aoffset5 + lda;
-            aoffset7 = aoffset6 + lda;
-            aoffset8 = aoffset7 + lda;
-            aoffset += 8 * lda;
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
 
-            i = (cols >> 3);
-            if (i > 0) {
-               do {
+                i = (cols >> 3);
+                if (i > 0) {
+                do {
                     C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
                     C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
                     C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1156,10 +1547,10 @@ class tinyBLAS_Q0_PPC {
                     t7 = vec_perm(t2, t4, swiz3);
                     t8 = vec_perm(t2, t4, swiz4);
                     if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
                     }
                     vec_xst(t5, 0, vecOffset);
                     vec_xst(t6, 0, vecOffset+16);
@@ -1175,10 +1566,10 @@ class tinyBLAS_Q0_PPC {
                     t7 = vec_perm(t2, t4, swiz3);
                     t8 = vec_perm(t2, t4, swiz4);
                     if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
                     }
                     vec_xst(t5, 0, vecOffset+64);
                     vec_xst(t6, 0, vecOffset+80);
@@ -1194,10 +1585,10 @@ class tinyBLAS_Q0_PPC {
                     t7 = vec_perm(t2, t4, swiz3);
                     t8 = vec_perm(t2, t4, swiz4);
                     if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
                     }
                     vec_xst(t5, 0, vecOffset+128);
                     vec_xst(t6, 0, vecOffset+144);
@@ -1213,10 +1604,10 @@ class tinyBLAS_Q0_PPC {
                     t7 = vec_perm(t2, t4, swiz3);
                     t8 = vec_perm(t2, t4, swiz4);
                     if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
+                        t5 = vec_xor(t5, xor_vector);
+                        t6 = vec_xor(t6, xor_vector);
+                        t7 = vec_xor(t7, xor_vector);
+                        t8 = vec_xor(t8, xor_vector);
                     }
                     vec_xst(t5, 0, vecOffset+192);
                     vec_xst(t6, 0, vecOffset+208);
@@ -1240,11 +1631,11 @@ class tinyBLAS_Q0_PPC {
     }
 
     if (rows & 4) {
-            aoffset1 = aoffset;
-            aoffset2 = aoffset1 + lda;
-            aoffset3 = aoffset2 + lda;
-            aoffset4 = aoffset3 + lda;
-            aoffset += 4 * lda;
+        aoffset1 = aoffset;
+        aoffset2 = aoffset1 + lda;
+        aoffset3 = aoffset2 + lda;
+        aoffset4 = aoffset3 + lda;
+        aoffset += 4 * lda;
 
         i = (cols >> 3);
             if (i > 0) {
@@ -1311,7 +1702,7 @@ class tinyBLAS_Q0_PPC {
             aoffset2 = aoffset1 + lda;
             aoffset3 = aoffset2 + lda;
             i = (cols >> 3);
-        if (i > 0) {
+            if (i > 0) {
                 do {
                     switch(rows) {
                         case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1527,13 +1918,18 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_4x8(int64_t ii, int64_t jj) {
         vec_t vec_A[8], vec_B[16] = {0};
         acc_t acc_0, acc_1;
-        std::array<int, 4> comparray;
+        std::array<int, 4> comparray {};
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
-            packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            if (std::is_same_v<TA, block_q4_0>) {
+               packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+            } else {
+               packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1545,15 +1941,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                 }
             }
-            auto aoffset = A+(ii*lda)+l;
-            for (int i = 0; i < 4; i++) {
-                comparray[i] = 0;
-                int ca = 0;
-                const int8_t *at = aoffset->qs;
-                for (int j = 0; j < 32; j++)
-                    ca += (int)*at++;
-                comparray[i] = ca;
-                aoffset += lda;
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 4; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
@@ -1565,13 +1963,18 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_8x4(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[8] = {0};
         acc_t acc_0, acc_1;
-        std::array<int, 8> comparray;
+        std::array<int, 8> comparray {};
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
-            packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            if (std::is_same_v<TA, block_q4_0>) {
+               packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+            } else {
+               packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1582,15 +1985,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                 }
             }
-            auto aoffset = A+(ii*lda)+l;
-            for (int i = 0; i < 8; i++) {
-                comparray[i] = 0;
-                int ca = 0;
-                const int8_t *at = aoffset->qs;
-                for (int j = 0; j < 32; j++)
-                    ca += (int)*at++;
-                comparray[i] = ca;
-                aoffset += lda;
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 8; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1602,15 +2007,20 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_8x8(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[16] = {0};
         acc_t acc_0, acc_1, acc_2, acc_3;
-        std::array<int, 8> comparray;
+        std::array<int, 8> comparray {};
         vector float fin_res[16] = {0};
         vector float vs[16] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
             __builtin_mma_xxsetaccz(&acc_2);
             __builtin_mma_xxsetaccz(&acc_3);
-            packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            if (std::is_same_v<TA, block_q4_0>) {
+               packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+            } else {
+               packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1624,15 +2034,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                 }
             }
-            auto aoffset = A+(ii*lda)+l;
-            for (int i = 0; i < 8; i++) {
-                comparray[i] = 0;
-                int ca = 0;
-                const int8_t *at = aoffset->qs;
-                for (int j = 0; j < 32; j++)
-                    ca += (int)*at++;
-                comparray[i] = ca;
-                aoffset += lda;
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 8; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1653,16 +2065,17 @@ class tinyBLAS_Q0_PPC {
         int64_t duty = (tiles + nth - 1) / nth;
         int64_t start = duty * ith;
         int64_t end = start + duty;
-        vec_t vec_A[8], vec_B[8] = {0};
+        vec_t vec_A[8] = {0}, vec_B[8] = {0};
         vector signed int vec_C[4];
         acc_t acc_0;
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
 
         if (end > tiles)
             end = tiles;
         for (int64_t job = start; job < end; ++job) {
             int64_t ii = m0 + job / xtiles * RM;
             int64_t jj = n0 + job % xtiles * RN;
-            std::array<int, RM> comparray;
+            std::array<int, 4> comparray{};
             vector float res[4] = {0};
             vector float fin_res[4] = {0};
             vector float vs[4] = {0};
@@ -1673,7 +2086,11 @@ class tinyBLAS_Q0_PPC {
                 __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
                 __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
                 __builtin_mma_xxsetaccz(&acc_0);
-                packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                if (isAblock_q4) {
+                   packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+                } else {
+                   packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                }
                 packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
                 for(int x = 0; x < 8; x+=4) {
                     __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1687,17 +2104,18 @@ class tinyBLAS_Q0_PPC {
                     }
                 }
                 __builtin_mma_disassemble_acc(vec_C, &acc_0);
-                auto aoffset = A+(ii*lda)+l;
-                for (int i = 0; i < RM; i++) {
-                    comparray[i] = 0;
-                    int ca = 0;
-                    const int8_t *at = aoffset->qs;
-                    for (int j = 0; j < 32; j++)
-                        ca += (int)*at++;
-                    comparray[i] = ca;
-                    aoffset += lda;
+                if (!isAblock_q4) {
+                    auto aoffset = A+(ii*lda)+l;
+                    for (int i = 0; i < RM; i++) {
+                        comparray[i] = 0;
+                        int ca = 0;
+                        auto *at = aoffset->qs;
+                        for (int j = 0; j < 32; j++)
+                            ca += (int)*at++;
+                        comparray[i] = ca;
+                        aoffset += lda;
+                    }
                 }
-
                 for (int i = 0; i < RM; i++) {
                     CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
                     res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
@@ -2013,6 +2431,7 @@ class tinyBLAS_PPC {
             }
         }
     }
+
     void KERNEL_4x4(int64_t ii, int64_t jj) {
         vec_t vec_A[4], vec_B[4], vec_C[4];
         acc_t acc_0;
@@ -2259,7 +2678,7 @@ class tinyBLAS_PPC {
             vec_t vec_C[4];
             acc_t acc_0;
             __builtin_mma_xxsetaccz(&acc_0);
-            vec_t vec_A[4], vec_B[4];
+            vec_t vec_A[4] {0}, vec_B[4] = {0};
             for (int l=0; l<k; l+=4) {
                 if (RN >= 4 && RM == 1) {
                     TA* a = const_cast<TA*>(A+(ii)*lda+l);
@@ -2503,8 +2922,8 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
-
 #elif defined(__MMA__)
+    //TO-DO: Remove this condition once gemv forwarding is enabled.
         if (n < 8 && n != 4)
            return false;
         if (m < 8 && m != 4)
@@ -2516,7 +2935,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
-
 #else
         return false;
 #endif
@@ -2541,6 +2959,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__MMA__)
+    //TO-DO: Remove this condition once gemv forwarding is enabled.
+        if (n < 8 && n != 4)
+           return false;
+        if (m < 8 && m != 4)
+           return false;
+        tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
+            k, (const block_q4_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            params->ith, params->nth};
+        tb.matmul(m, n);
+        return true;
 #else
         return false;
 #endif