]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : fixes for AVXVNNI instruction set with MSVC and Clang (llama/11027)
authorSrihari-mcw <redacted>
Tue, 31 Dec 2024 14:23:33 +0000 (19:53 +0530)
committerGeorgi Gerganov <redacted>
Sat, 4 Jan 2025 08:45:01 +0000 (10:45 +0200)
* Fixes for clang AVX VNNI

* enable AVX VNNI and alder lake build for MSVC

* Apply suggestions from code review

---------

Co-authored-by: slaren <redacted>
ggml/src/CMakeLists.txt
ggml/src/ggml-cpu/CMakeLists.txt
ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
ggml/src/ggml-cpu/ggml-cpu-quants.c
ggml/src/ggml-cpu/llamafile/sgemm.cpp

index a5f7f7b5b85a69f2cdbfcb1cb025733cf42c5392..84101c32c2b50171a166ae652a6b8fe78557c2e9 100644 (file)
@@ -290,9 +290,9 @@ if (GGML_CPU_ALL_VARIANTS)
     ggml_add_cpu_backend_variant(haswell        AVX F16C AVX2 FMA)
     ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 FMA AVX512)
     ggml_add_cpu_backend_variant(icelake        AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
+    ggml_add_cpu_backend_variant(alderlake      AVX F16C AVX2 FMA AVX_VNNI)
     if (NOT MSVC)
-        # MSVC doesn't support AVX-VNNI or AMX
-        ggml_add_cpu_backend_variant(alderlake      AVX F16C AVX2 FMA AVX_VNNI)
+        # MSVC doesn't support AMX
         ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
     endif()
 else ()
index f0aecac1bd1c614be10513877a253956150e2767..6b3641c4263c711d176aecf678075a44ae25f9b0 100644 (file)
@@ -215,8 +215,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 list(APPEND ARCH_DEFINITIONS GGML_SSE42)
             endif()
             if (GGML_AVX_VNNI)
-                # MSVC generates AVX512 with AVX-VNNI intrinsics even with /arch:AVX2
-                #list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
+                list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
             endif()
         else ()
             if (GGML_NATIVE)
index 2d79b8b611de33ec2de632a27788f37f532fec27..622c63f1f8ebfbdb87ad229e499bcc560d4a44f0 100644 (file)
@@ -194,9 +194,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
 }
 
 static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
-#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
     const __m256i zero = _mm256_setzero_si256();
     return _mm256_dpbusd_epi32(zero, ax, sy);
+#elif defined(__AVXVNNI__)
+    const __m256i zero = _mm256_setzero_si256();
+    return _mm256_dpbusd_avx_epi32(zero, ax, sy);
 #else
     // Perform multiplication and create 16-bit values
     const __m256i dot = _mm256_maddubs_epi16(ax, sy);
index 634c5fa1162c3c4b6e030a713a953679b4f0b273..8e14722667abbc03f906cb96d801e8076592aa3a 100644 (file)
@@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
 }
 
 static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
-#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
     return _mm256_cvtepi32_ps(summed_pairs);
+#elif defined(__AVXVNNI__)
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
 #else
     // Perform multiplication and create 16-bit values
     const __m256i dot = _mm256_maddubs_epi16(ax, sy);
index 00f7f11704ead7ddcc6489042323d0484b6a505b..8fce576c3e4739fa03ac329cb62ea3854b97c8de 100644 (file)
@@ -1000,8 +1000,10 @@ class tinyBLAS_Q0_AVX {
 
     inline __m256 updot(__m256i u, __m256i s) {
         __m256i res;
-#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
         res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
+#elif defined(__AVXVNNI__)
+        res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
 #else
         res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
 #endif