#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native")
+# warning flags
+
+if (NOT MSVC)
+ add_compile_options(-Werror=vla)
+endif()
+
# dependencies
set(CMAKE_C_STANDARD 11)
+#define _USE_MATH_DEFINES // for M_PI
+
#include "common.h"
// third-party utilities
#include <codecvt>
#include <sstream>
-#ifndef M_PI
-#define M_PI 3.14159265358979323846
-#endif
-
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
+#define _USE_MATH_DEFINES // for M_PI
+
#include "ggml.h"
#include "common.h"
# compiler flags
if (NOT MSVC)
- set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror=vla")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-math-errno -ffinite-math-only -funsafe-math-optimizations")
endif()
if (F16C_M MATCHES "f16c")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
endif()
+ elseif (MSVC)
+ if (GGML_AVX512)
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX512")
+ # MSVC has no compile-time flags enabling specific
+ # AVX512 extensions, neither it defines the
+ # macros corresponding to the extensions.
+ # Do it manually.
+ if (GGML_AVX512_VBMI)
+ add_compile_definitions(__AVX512VBMI__)
+ endif()
+ if (GGML_AVX512_VNNI)
+ add_compile_definitions(__AVX512VNNI__)
+ endif()
+ elseif (GGML_AVX2)
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
+ elseif (GGML_AVX)
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX")
+ endif()
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c -mavx -mavx2")
endif()
#if defined(__AVX2__)
{
assert(QK == 64);
- const int QK8 = QK/8;
+ enum { QK8 = QK/8 };
__m256 srcv[QK8];
__m256 minv[QK8];
#if defined(__AVX2__)
{
assert(QK == 64);
- const int QK8 = QK/8;
+ enum { QK8 = QK/8 };
__m256 srcv [QK8];
__m256 asrcv[QK8];
}
for (int l = 0; l < QK8; l++) {
- asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff));
+ asrcv[l] = _mm256_and_ps(srcv[l], _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff)));
}
#if defined(__AVX2__)
{
- const int QK8 = 4;
+ enum { QK8 = 4 };
__m256 srcv [QK8];
__m256 asrcv[QK8];
}
for (int l = 0; l < QK8; l++) {
- asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff));
+ asrcv[l] = _mm256_and_ps(srcv[l], _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff)));
}
for (int l = 0; l < QK8/2; l++) {