]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cmake : MSVC instruction detection (fixed up #809) (#3923)
authorEve <redacted>
Sun, 5 Nov 2023 08:03:09 +0000 (08:03 +0000)
committerGitHub <redacted>
Sun, 5 Nov 2023 08:03:09 +0000 (10:03 +0200)
* Add detection code for avx

* Only check hardware when option is ON

* Modify per code review sugguestions

* Build locally will detect CPU

* Fixes CMake style to use lowercase like everywhere else

* cleanup

* fix merge

* linux/gcc version for testing

* msvc combines avx2 and fma into /arch:AVX2 so check for both

* cleanup

* msvc only version

* style

* Update FindSIMD.cmake

---------

Co-authored-by: Howard Su <redacted>
Co-authored-by: Jeremy Dunn <redacted>
CMakeLists.txt
cmake/FindSIMD.cmake [new file with mode: 0644]

index 3c49d645c3196f0c0c6a41c2bb466d08c7c9c8ca..7b4eb18403c0bfa1e64718cb7361429d73491270 100644 (file)
@@ -10,7 +10,7 @@ endif()
 
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
 
-if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
+if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
     set(LLAMA_STANDALONE ON)
 
     # configure project version
@@ -44,7 +44,7 @@ endif()
 
 # general
 option(LLAMA_STATIC                     "llama: static link libraries"                          OFF)
-option(LLAMA_NATIVE                     "llama: enable -march=native flag"                      OFF)
+option(LLAMA_NATIVE                     "llama: enable -march=native flag"                      ON)
 option(LLAMA_LTO                        "llama: enable link time optimization"                  OFF)
 
 # debug
@@ -510,6 +510,10 @@ if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATC
 elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "^(x86_64|i686|amd64|x64)$" )
     message(STATUS "x86 detected")
     if (MSVC)
+        # instruction set detection for MSVC only
+        if (LLAMA_NATIVE)
+            include(cmake/FindSIMD.cmake)
+        endif ()
         if (LLAMA_AVX512)
             add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
             add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
diff --git a/cmake/FindSIMD.cmake b/cmake/FindSIMD.cmake
new file mode 100644 (file)
index 0000000..33377ec
--- /dev/null
@@ -0,0 +1,100 @@
+include(CheckCSourceRuns)
+
+set(AVX_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m256 a;
+        a = _mm256_set1_ps(0);
+        return 0;
+    }
+")
+
+set(AVX512_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0);
+        __m512i b = a;
+        __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
+        return 0;
+    }
+")
+
+set(AVX2_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m256i a = {0};
+        a = _mm256_abs_epi16(a);
+        __m256i x;
+        _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
+        return 0;
+    }
+")
+
+set(FMA_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m256 acc = _mm256_setzero_ps();
+        const __m256 d = _mm256_setzero_ps();
+        const __m256 p = _mm256_setzero_ps();
+        acc = _mm256_fmadd_ps( d, p, acc );
+        return 0;
+    }
+")
+
+macro(check_sse type flags)
+    set(__FLAG_I 1)
+    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
+    foreach (__FLAG ${flags})
+        if (NOT ${type}_FOUND)
+            set(CMAKE_REQUIRED_FLAGS ${__FLAG})
+            check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I})
+            if (HAS_${type}_${__FLAG_I})
+                set(${type}_FOUND TRUE CACHE BOOL "${type} support")
+                set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
+            endif()
+            math(EXPR __FLAG_I "${__FLAG_I}+1")
+        endif()
+    endforeach()
+    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
+
+    if (NOT ${type}_FOUND)
+        set(${type}_FOUND FALSE CACHE BOOL "${type} support")
+        set(${type}_FLAGS "" CACHE STRING "${type} flags")
+    endif()
+
+    mark_as_advanced(${type}_FOUND ${type}_FLAGS)
+endmacro()
+
+# flags are for MSVC only!
+check_sse("AVX" " ;/arch:AVX")
+if (NOT ${AVX_FOUND})
+    set(LLAMA_AVX OFF)
+else()
+    set(LLAMA_AVX ON)
+endif()
+
+check_sse("AVX2" " ;/arch:AVX2")
+check_sse("FMA" " ;/arch:AVX2")
+if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))
+    set(LLAMA_AVX2 OFF)
+else()
+    set(LLAMA_AVX2 ON)
+endif()
+
+check_sse("AVX512" " ;/arch:AVX512")
+if (NOT ${AVX512_FOUND})
+    set(LLAMA_AVX512 OFF)
+else()
+    set(LLAMA_AVX512 ON)
+endif()