]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cpu : optimize the ggml NORM operation (llama/15953)
authorduduta <redacted>
Thu, 9 Oct 2025 19:11:15 +0000 (22:11 +0300)
committerGeorgi Gerganov <redacted>
Sun, 12 Oct 2025 08:16:23 +0000 (11:16 +0300)
* ggml-cpu: optimize norm operation to use intrinsics or Accelerate

          rename function

          add endif macro comment

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: Aaron Teo <redacted>
* implement s390x SIMD suggested by @taronaeo

* add TODO comment

* tidy up spaces

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: Aaron Teo <redacted>
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cpu/vec.cpp
ggml/src/ggml-cpu/vec.h

index 8e1a2de14f9833163489cfab24eb39b8d3195e95..1c43865ff65fcff2677c50b9981b24672f1ea396 100644 (file)
@@ -3467,31 +3467,27 @@ static void ggml_compute_forward_norm_f32(
 
     GGML_ASSERT(eps >= 0.0f);
 
-    // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
-                ggml_float sum = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    sum += (ggml_float)x[i00];
-                }
-
+                float sum = 0.0;
+                ggml_vec_sum_f32(ne00, &sum, x);
                 float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+                float variance = 0;
 
-                ggml_float sum2 = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    float v = x[i00] - mean;
-                    y[i00] = v;
-                    sum2 += (ggml_float)(v*v);
-                }
+#ifdef GGML_USE_ACCELERATE
+                mean = -mean;
+                vDSP_vsadd(x, 1, &mean, y, 1, ne00);
+                vDSP_measqv(y, 1, &variance, ne00);
+#else
+                variance = ggml_vec_cvar_f32(ne00, y, x, mean);
+#endif //GGML_USE_ACCELERATE
 
-                float variance = sum2/ne00;
                 const float scale = 1.0f/sqrtf(variance + eps);
-
                 ggml_vec_scale_f32(ne00, y, scale);
             }
         }
index 437192d525a34fd072ec058cc94e7085cfeaecd8..b8e37052d35e15c2f0f28acf8389e7e6337314e4 100644 (file)
@@ -404,6 +404,72 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
     }
 }
 
+ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
+    int i = 0;
+    ggml_float sum = 0;
+// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
+// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
+                                   _mm512_set1_ps(mean));
+        _mm512_storeu_ps(y + i, val);
+        sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
+                                   _mm256_set1_ps(mean));
+        _mm256_storeu_ps(y + i, val);
+        val = _mm256_mul_ps(val,val);
+        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+                                 _mm256_castps256_ps128(val));
+        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+        sum += (ggml_float)_mm_cvtss_f32(val2);
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
+                                _mm_set1_ps(mean));
+        _mm_storeu_ps(y + i, val);
+        val = _mm_mul_ps(val, val);
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+        val = _mm_add_ss(val, _mm_movehdup_ps(val));
+#else
+        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+        val = _mm_add_ps(val, tmp);
+        tmp = _mm_movehl_ps(tmp, val);
+        val = _mm_add_ss(val, tmp);
+#endif  // __AVX__ || __AVX2__ || __AVX512F__
+        sum += (ggml_float)_mm_cvtss_f32(val);
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        float32x4_t val = vsubq_f32(vld1q_f32(x + i),
+                                    vdupq_n_f32(mean));
+        vst1q_f32(y + i, val);
+        val = vmulq_f32(val, val);
+        sum += (ggml_float)vaddvq_f32(val);
+    }
+#elif defined(__VXE__) || defined(__VXE2__)
+    for (; i + 3 < n; i += 4) {
+        float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));
+        vec_xst(val, 0, y + i);
+        val = vec_mul(val, val);
+        sum += (ggml_float)vec_hsum_f32x4(val);
+    }
+#endif
+    for (; i < n; ++i) {
+        float val = x[i] - mean;
+        val *= val;
+        sum += (ggml_float)val;
+        y[i] = val;
+    }
+    return sum/n;
+}
+
 ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
     int i = 0;
     ggml_float sum = 0;
index f95ca94e54b16cbf27b64f1d0ba023cfd78da913..2751359ce49f456082239dfeed8cfe1c8f6e922b 100644 (file)
@@ -44,6 +44,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
 void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
 
 void ggml_vec_silu_f32(const int n, float * y, const float * x);
+ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean )
 ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
 ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);