git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add provisions for windows support for BF16 code including CMake provision for enabling AVX512_BF16
author: Srihari-mcw <redacted>
Mon, 20 May 2024 02:18:39 +0000 (19:18 -0700)
committer: GitHub <redacted>
Mon, 20 May 2024 02:18:39 +0000 (12:18 +1000)
CMakeLists.txt
ggml-impl.h
ggml.c
ggml.h
llama.cpp

index 616698c7fe28c8174504232ce6cdf4a7d6a3c806..92c9f09eb5d9d75e378de7beae314c45c1618a3a 100644 (file)
@@ -77,6 +77,7 @@ option(LLAMA_AVX2                            "llama: enable AVX2"
 option(LLAMA_AVX512                          "llama: enable AVX512"                             OFF)
 option(LLAMA_AVX512_VBMI                     "llama: enable AVX512-VBMI"                        OFF)
 option(LLAMA_AVX512_VNNI                     "llama: enable AVX512-VNNI"                        OFF)
+option(LLAMA_AVX512_BF16                     "llama: enable AVX512-BF16"                        OFF)
 option(LLAMA_FMA                             "llama: enable FMA"                                ${INS_ENB})
 # in MSVC F16C is implied with AVX2/AVX512
 if (NOT MSVC)
@@ -1060,6 +1061,10 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
                 add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
                 add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
             endif()
+            if (LLAMA_AVX512_BF16)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
+            endif()
         elseif (LLAMA_AVX2)
             list(APPEND ARCH_FLAGS /arch:AVX2)
         elseif (LLAMA_AVX)
@@ -1091,6 +1096,9 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
         if (LLAMA_AVX512_VNNI)
             list(APPEND ARCH_FLAGS -mavx512vnni)
         endif()
+        if (LLAMA_AVX512_BF16)
+            list(APPEND ARCH_FLAGS -mavx512bf16)
+        endif()
     endif()
 elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
     message(STATUS "PowerPC detected")
index 59684fa81f07e8d81010147bbf71b3842111fddd..5ff014fe3f339d5d3716718d935ce5ba802c7dd0 100644 (file)
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+#if defined(_WIN32)
+
+#define m512bh(p) p
+#define m512i(p) p
+
+#else
+
+#define m512bh(p) (__m512bh)(p)
+#define m512i(p) (__m512i)(p)
+
+#endif
+
 /**
  * Converts brain16 to float32.
  *
diff --git a/ggml.c b/ggml.c
index 3a104c486339e5cead080f9405cc898a2076d3f1..53da231ee061c46d503564a6f7dd66395474934f 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -406,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
   int i = 0;
 #if defined(__AVX512BF16__)
   for (; i + 32 <= n; i += 32) {
-        _mm512_storeu_ps(
-            (__m512 *)(y + i),
-            (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
-                                        _mm512_loadu_ps(x + i)));
+        _mm512_storeu_si512(
+            (__m512i *)(y + i),
+            m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+                                _mm512_loadu_ps(x + i))));
   }
 #endif
     for (; i < n; i++) {
@@ -1666,10 +1666,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
     __m512 c1 = _mm512_setzero_ps();
     __m512 c2 = _mm512_setzero_ps();
     for (; i + 64 <= n; i += 64) {
-        c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
-                              (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
-        c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
-                              (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
+        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+                             m512bh(_mm512_loadu_si512((y + i))));
+        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+                             m512bh(_mm512_loadu_si512((y + i + 32))));
     }
     sumf += (ggml_float)_mm512_reduce_add_ps(c1);
     sumf += (ggml_float)_mm512_reduce_add_ps(c2);
@@ -23137,6 +23137,14 @@ int ggml_cpu_has_avx512_vnni(void) {
 #endif
 }
 
+int ggml_cpu_has_avx512_bf16(void) {
+#if defined(__AVX512BF16__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
diff --git a/ggml.h b/ggml.h
index 8c13f4ba89c6e71a492a182eadbb008e8e810d77..77475710129d7986b1684e56fc45fbd28cb82d5d 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -2390,6 +2390,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_avx512     (void);
     GGML_API int ggml_cpu_has_avx512_vbmi(void);
     GGML_API int ggml_cpu_has_avx512_vnni(void);
+    GGML_API int ggml_cpu_has_avx512_bf16(void);
     GGML_API int ggml_cpu_has_fma        (void);
     GGML_API int ggml_cpu_has_neon       (void);
     GGML_API int ggml_cpu_has_arm_fma    (void);
index 102bc2020aaac05e9085fe94949c3203489f12b6..ca3e9fcc0c24953502e1e2bbe2d760c296c6ef4b 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -18074,6 +18074,7 @@ const char * llama_print_system_info(void) {
     s += "AVX512 = "      + std::to_string(ggml_cpu_has_avx512())      + " | ";
     s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | ";
     s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
+    s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | ";
     s += "FMA = "         + std::to_string(ggml_cpu_has_fma())         + " | ";
     s += "NEON = "        + std::to_string(ggml_cpu_has_neon())        + " | ";
     s += "ARM_FMA = "     + std::to_string(ggml_cpu_has_arm_fma())     + " | ";