]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
add AVX support
authorkatsu560 <redacted>
Wed, 23 Nov 2022 11:23:24 +0000 (20:23 +0900)
committerGeorgi Gerganov <redacted>
Wed, 23 Nov 2022 20:16:33 +0000 (22:16 +0200)
ggml.c
ggml.h
whisper.cpp

diff --git a/ggml.c b/ggml.c
index d58d9889624b8a5c79b3cecfd499d100a921a570..ab49c23bdcec424f6ff6de1b14efe028330d3fa2 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -372,6 +372,49 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
 
     sumf = _mm_cvtss_f32(r1);
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        sumf += x[i]*y[i];
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    __m256 sum0 = _mm256_setzero_ps();
+    __m256 sum1 = _mm256_setzero_ps();
+    __m256 sum2 = _mm256_setzero_ps();
+    __m256 sum3 = _mm256_setzero_ps();
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_loadu_ps(x + i + 0);
+        x1 = _mm256_loadu_ps(x + i + 8);
+        x2 = _mm256_loadu_ps(x + i + 16);
+        x3 = _mm256_loadu_ps(x + i + 24);
+
+        y0 = _mm256_loadu_ps(y + i + 0);
+        y1 = _mm256_loadu_ps(y + i + 8);
+        y2 = _mm256_loadu_ps(y + i + 16);
+        y3 = _mm256_loadu_ps(y + i + 24);
+
+       sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
+       sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
+       sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
+       sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
+    }
+
+    sum0 = _mm256_add_ps(sum0, sum1);
+    sum2 = _mm256_add_ps(sum2, sum3);
+    sum0 = _mm256_add_ps(sum0, sum2);
+
+    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1));
+    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+    sumf = _mm_cvtss_f32(r1);
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         sumf += x[i]*y[i];
@@ -569,6 +612,50 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     sumf = _mm_cvtss_f32(r1);
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        //GGML_ASSERT(false);
+        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    __m256 sum0 = _mm256_setzero_ps();
+    __m256 sum1 = _mm256_setzero_ps();
+    __m256 sum2 = _mm256_setzero_ps();
+    __m256 sum3 = _mm256_setzero_ps();
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+       sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
+       sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
+       sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
+       sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
+    }
+
+    const __m256 sum01 = _mm256_add_ps(sum0, sum1);
+    const __m256 sum23 = _mm256_add_ps(sum2, sum3);
+    const __m256 sum0123 = _mm256_add_ps(sum01, sum23);
+
+    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1));
+    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+    sumf = _mm_cvtss_f32(r1);
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         //GGML_ASSERT(false);
@@ -698,6 +785,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
         _mm256_storeu_ps(y + i + 24, y3);
     }
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    const __m256 v4 = _mm256_set1_ps(v);
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_loadu_ps(x + i + 0);
+        x1 = _mm256_loadu_ps(x + i + 8);
+        x2 = _mm256_loadu_ps(x + i + 16);
+        x3 = _mm256_loadu_ps(x + i + 24);
+
+        y0 = _mm256_loadu_ps(y + i + 0);
+        y1 = _mm256_loadu_ps(y + i + 8);
+        y2 = _mm256_loadu_ps(y + i + 16);
+        y3 = _mm256_loadu_ps(y + i + 24);
+
+       y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0);
+       y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1);
+       y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2);
+       y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3);
+
+        _mm256_storeu_ps(y + i + 0, y0);
+        _mm256_storeu_ps(y + i + 8, y1);
+        _mm256_storeu_ps(y + i + 16, y2);
+        _mm256_storeu_ps(y + i + 24, y3);
+    }
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         y[i] += x[i]*v;
@@ -859,6 +981,42 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
         _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
     }
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        GGML_ASSERT(false);
+        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    const __m256 v8 = _mm256_set1_ps(v);
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+       y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0);
+       y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1);
+       y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2);
+       y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3);
+
+        _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
+    }
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         GGML_ASSERT(false);
@@ -8081,6 +8239,14 @@ enum ggml_opt_result ggml_opt(
 
 ////////////////////////////////////////////////////////////////////////////////
 
+int ggml_cpu_has_avx(void) {
+#if defined(__AVX__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_avx2(void) {
 #if defined(__AVX2__)
     return 1;
diff --git a/ggml.h b/ggml.h
index f352e716350c5581e8877fc2b7efea4edcf4acf4..3e4e962a69ec166538a817876f2cbe0c391a89f3 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -723,6 +723,7 @@ enum ggml_opt_result ggml_opt(
 // system info
 //
 
+int ggml_cpu_has_avx(void);
 int ggml_cpu_has_avx2(void);
 int ggml_cpu_has_avx512(void);
 int ggml_cpu_has_neon(void);
index 4f23cde40b20c2905f61f4a9e6bf18863b9a5291..d729dba52181ae4e68acbf98276480ed4561ff0f 100644 (file)
@@ -3041,6 +3041,7 @@ const char * whisper_print_system_info() {
     static std::string s;
 
     s  = "";
+    s += "AVX = "       + std::to_string(ggml_cpu_has_avx())       + " | ";
     s += "AVX2 = "      + std::to_string(ggml_cpu_has_avx2())      + " | ";
     s += "AVX512 = "    + std::to_string(ggml_cpu_has_avx512())    + " | ";
     s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";