]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Add AVX,AVX2 support for ggml_vec_scale_f32
authorkatsu560 <redacted>
Fri, 16 Dec 2022 23:42:30 +0000 (08:42 +0900)
committerGeorgi Gerganov <redacted>
Sat, 17 Dec 2022 17:40:10 +0000 (19:40 +0200)
ggml.c

diff --git a/ggml.c b/ggml.c
index c5780ed2579a59363314985d1ea03af4ec486dde..f1d2b2564325204f4dd3dcc7a89ad7a857d665c1 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1118,7 +1118,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
 #endif
 }
 
-inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
+//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
+inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
+#if defined(__AVX__) || defined(__AVX2__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    const __m256 v4 = _mm256_set1_ps(v);
+
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        y0 = _mm256_loadu_ps(y + i + 0);
+        y1 = _mm256_loadu_ps(y + i + 8);
+        y2 = _mm256_loadu_ps(y + i + 16);
+        y3 = _mm256_loadu_ps(y + i + 24);
+
+       y0 = _mm256_mul_ps(y0, v4);
+       y1 = _mm256_mul_ps(y1, v4);
+       y2 = _mm256_mul_ps(y2, v4);
+       y3 = _mm256_mul_ps(y3, v4);
+
+        _mm256_storeu_ps(y + i + 0, y0);
+        _mm256_storeu_ps(y + i + 8, y1);
+        _mm256_storeu_ps(y + i + 16, y2);
+        _mm256_storeu_ps(y + i + 24, y3);
+    }
+
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        y[i] *= v;
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] *= v;
+    }
+#endif
+}
+
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }