]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add ARM_NEON quantize_row_q4_1()
authorGeorgi Gerganov <redacted>
Wed, 29 Mar 2023 19:03:02 +0000 (22:03 +0300)
committerGeorgi Gerganov <redacted>
Wed, 29 Mar 2023 19:03:07 +0000 (22:03 +0300)
ggml.c

diff --git a/ggml.c b/ggml.c
index 0906cf90ec3787d046314cf5eed08e4facfd9e27..51cd3b91ca810b0213fe441f3af2714fae310924 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -564,10 +564,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         }
     }
 #elif __ARM_NEON
-    uint8_t pp[QK/2];
     for (int i = 0; i < nb; i++) {
-        float amax = 0.0f; // absolute max
-
         float32x4_t srcv [8];
         float32x4_t asrcv[8];
         float32x4_t amaxv[8];
@@ -579,7 +576,8 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
         for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
 
-        amax = MAX(
+        // absolute max
+        const float amax = MAX(
                 MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
                 MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
 
@@ -593,11 +591,9 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
             const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
             const int32x4_t   vi = vcvtq_s32_f32(vf);
 
-            pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
-            pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+            y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
+            y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
         }
-
-        memcpy(y[i].qs, pp, sizeof(pp));
     }
 #elif defined(__AVX2__)
     for (int i = 0; i < nb; i++) {
@@ -665,7 +661,6 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         _mm_storeu_si128( ( __m128i* )y[i].qs, res );
     }
 #elif defined(__wasm_simd128__)
-    uint8_t pp[QK/2];
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
 
@@ -694,11 +689,9 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
             const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
             const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
 
-            pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
-            pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
+            y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
+            y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
         }
-
-        memcpy(y[i].qs, pp, sizeof(pp));
     }
 #else
     // scalar
@@ -750,11 +743,11 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
 static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK == 0);
 
-#if defined(__AVX2__)
     const int nb = k / QK;
 
     block_q4_1 * restrict y = vy;
 
+#if defined(__AVX2__)
     for (int i = 0; i < nb; i++) {
         // Load elements into 4 AVX vectors
         __m256 v0 = _mm256_loadu_ps( x );
@@ -828,6 +821,41 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
         __m128i res = packNibbles( i0 );
         _mm_storeu_si128( ( __m128i* )y[i].qs, res );
     }
+#elif __ARM_NEON
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv[8];
+        float32x4_t minv[8];
+        float32x4_t maxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
+
+        for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
+        for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
+        for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l + 4]);
+
+        for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]);
+        for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]);
+        for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l + 4]);
+
+        const float min = vminvq_f32(minv[0]);
+        const float max = vmaxvq_f32(maxv[0]);
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+        y[i].m = min;
+
+        const float32x4_t minv0 = vdupq_n_f32(min);
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
+            const int32x4_t   vi = vcvtq_s32_f32(v);
+
+            y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
+            y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+        }
+    }
 #else
     // scalar
     quantize_row_q4_1_reference(x, vy, k);