ggml : vector length agnostic SVE support (llama/9290)
author     Prashant Vithule <redacted>
Mon, 9 Sep 2024 15:37:18 +0000 (21:07 +0530)
committer  Georgi Gerganov <redacted>
Fri, 20 Sep 2024 19:03:57 +0000 (22:03 +0300)
* Implemented vector length agnostic SVE using switch case for 512-bit, 256-bit, 128-bit vector lengths

* Removed whitespace

* ggml : style changes + fix 512-bit nb loop check

- fix local scope in switch cases
- consistent predicate names
- empty lines when necessary
- opening braces, spaces
- const-correctness
- add asserts

* Update ggml/src/ggml-quants.c

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
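
For context: the previous SVE path only ran when ggml_sve_cnt_b == QK8_0 (i.e. on 256-bit registers); this commit instead computes the register width at run time and dispatches to a per-width kernel. The sketch below illustrates that dispatch pattern only; it is not code from the commit, the function name is made up, and svcntb() is the standard SVE ACLE intrinsic whose byte count ggml caches in ggml_sve_cnt_b.

    #include <arm_sve.h>   // SVE ACLE intrinsics (requires an SVE-enabled target)
    #include <assert.h>

    // Illustrative sketch of the vector length agnostic (VLA) dispatch used in
    // the kernels below: query the runtime register width once, then select a
    // specialization per width.
    static void vla_dispatch_sketch(void) {
        // svcntb() returns the SVE register width in bytes at run time;
        // ggml caches the same value in ggml_sve_cnt_b, so width = ggml_sve_cnt_b*8.
        const int vector_length = (int) svcntb() * 8;

        switch (vector_length) {
            case 128: /* 16 int8 lanes per register */ break;
            case 256: /* 32 int8 lanes per register */ break;
            case 512: /* 64 int8 lanes per register */ break;
            default:
                assert(0 && "Unsupported vector length");
                break;
        }
    }

Each case in the diff then uses predicates sized for that width (for example svptrue_pat_b8(SV_VL16) to address a 128-bit slice of a wider register), so the same 32-element quant blocks are processed regardless of the hardware vector length.
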
src/ggml-quants.c

index 8c31e2ccabda0b11ed6064f2bf13bb8e28e02792..322c85d2a816e1af61f5bc4c071f7f9757e29533 100644 (file)
@@ -4003,42 +4003,141 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     float sumf = 0;
 
 #if defined(__ARM_FEATURE_SVE)
-    if (ggml_sve_cnt_b == QK8_0) {
-        const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
-        const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
 
-        svfloat32_t sumv0 = svdup_n_f32(0.0f);
-        svfloat32_t sumv1 = svdup_n_f32(0.0f);
+    const int vector_length = ggml_sve_cnt_b*8;
 
-        for (; ib + 1 < nb; ib += 2) {
-            const block_q4_0 * restrict x0 = &x[ib + 0];
-            const block_q4_0 * restrict x1 = &x[ib + 1];
-            const block_q8_0 * restrict y0 = &y[ib + 0];
-            const block_q8_0 * restrict y1 = &y[ib + 1];
-
-            // load x
-            const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
-            const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
-
-            // 4-bit -> 8-bit
-            const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
-            const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
-
-            // sub 8
-            const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
-            const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+    // VLA Implementation using switch case
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating higher lanes for 4 float32 elements
+                const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
+                    const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
+                    const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
+                    const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));
+
+                    // sub 8
+                    const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
+                    const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
+                    const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
+                    const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);
+
+                    // load y
+                    const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
+                    const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
+                    const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
+                                    svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
+                                    svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
 
-            // load y
-            const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
-            const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for 16 int8 elements
+                const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
 
-            // dot product
-            sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-            sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-        }
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for activating higher lanes for 32 int8 elements
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
+                const svbool_t pl16 = svnot_b_z(ph32, ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(ph32, y0->qs);
+                    const svint8_t qy1 = svld1_s8(ph32, y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
+                                svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
+                                svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
 
-        sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+                sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
+            } break;
+        default:
+            assert(false && "Unsupported vector length");
+            break;
     }
+
 #elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -5488,29 +5587,124 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     float sumf = 0;
 
 #if defined(__ARM_FEATURE_SVE)
-    if (ggml_sve_cnt_b == QK8_0) {
-        svfloat32_t sumv0 = svdup_n_f32(0.0f);
-        svfloat32_t sumv1 = svdup_n_f32(0.0f);
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
 
-        for (; ib + 1 < nb; ib += 2) {
-            const block_q8_0 * restrict x0 = &x[ib + 0];
-            const block_q8_0 * restrict x1 = &x[ib + 1];
-            const block_q8_0 * restrict y0 = &y[ib + 0];
-            const block_q8_0 * restrict y1 = &y[ib + 1];
+    const int vector_length = ggml_sve_cnt_b*8;
 
-            // load x
-            const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
-            const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+    // VLA Implementation for SVE
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
+                // predicate for activating lanes for 4 float32 elements
+                const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
+                    const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
+                    const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
+                    const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);
+
+                    // load y
+                    const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
+                    const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
+                    const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
+                    const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);
+
+                    sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
+                                    svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
+                                    svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
 
-            // load y
-            const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
-            const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+                sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+                    const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
 
-            sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-            sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-        }
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for activating high 256 bit
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+                // predicate for activating low 256 bit
+                const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);
+
+                // predicate for activating high lanes for 8 float32 elements
+                const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
+                // predicate for activating low lanes for 8 float32 elements
+                const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);
+
+                svfloat32_t sumv00 = svdup_n_f32(0.0f);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load the 32 int8_t of x0 into the lower half of the vector and the 32 int8_t
+                    // of x1 into the upper half (the +2 offset skips x1's fp16 scale), then merge
+                    // them into one 64-element vector
+                    // load x
+                    const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
+                          svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
+
+                    qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);
 
-        sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+                    // load y
+                    const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
+                          svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);
+
+                    qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);
+
+                    // scale creation
+                    const float32_t deq1 = GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d);
+                    const float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d);
+
+                    // duplicate deq1 in first half of vector and deq2 in second half of vector
+                    const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);
+
+                    const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));
+
+                    sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), sumv00);
+                break;
+            }
+        default:
+            assert(false && "Unsupported vector length");
+            break;
     }
 #elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);