]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
faster avx512 exp implementation (llama/7551)
authorChris Elrod <redacted>
Thu, 30 May 2024 11:32:55 +0000 (07:32 -0400)
committerGeorgi Gerganov <redacted>
Sun, 16 Jun 2024 15:19:48 +0000 (18:19 +0300)
* faster avx512 exp implementation

* x->r

* improve accuracy, handle special cases

* remove `e`

ggml.c

diff --git a/ggml.c b/ggml.c
index f3a90ff2c06329ec26c04d09f327077f05051cb2..76803639c97f615a6ac22d637d2e13d96a754c5a 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -2315,32 +2315,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
   const __m512 r = _mm512_set1_ps(0x1.8p23f);
   const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
   const __m512 n = _mm512_sub_ps(z, r);
-  const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
-                                    _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
-  const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
-  const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
-  const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
-  const __m512 u = _mm512_mul_ps(b, b);
-  const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
-                                                                   _mm512_set1_ps(0x1.573e2ep-5f)), u,
-                                                   _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
-                                                                   _mm512_set1_ps(0x1.fffdb6p-2f))),
-                                   u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
-  if (_mm512_kortestz(c, c))
-    return _mm512_fmadd_ps(j, k, k);
-  const __m512i g = _mm512_and_si512(
-      _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
-      _mm512_set1_epi32(0x82000000u));
-  const __m512 s1 =
-      _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
-  const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
+  const __m512 b =
+      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
   const __mmask16 d =
       _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
-  return _mm512_mask_blend_ps(
-      d, _mm512_mask_blend_ps(
-          c, _mm512_fmadd_ps(k, j, k),
-          _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
-      _mm512_mul_ps(s1, s1));
+  const __m512 u = _mm512_mul_ps(b, b);
+  const __m512 j = _mm512_fmadd_ps(
+      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+                                      _mm512_set1_ps(0x1.573e2ep-5f)),
+                      u,
+                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
+      u,
+      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+  const __m512 res = _mm512_scalef_ps(j, n);
+  if (_mm512_kortestz(d, d))
+    return res;
+  const __m512 zero = _mm512_setzero_ps();
+  const __m512 alt = _mm512_mask_blend_ps(
+      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+  return _mm512_mask_blend_ps(d, res, alt);
 }
 
 // computes silu x/(1+exp(-x)) in single precision vector