]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Enable build with CUDA 11.0 (make) (#3132)
authorVlad <redacted>
Sat, 16 Sep 2023 14:55:43 +0000 (17:55 +0300)
committerGitHub <redacted>
Sat, 16 Sep 2023 14:55:43 +0000 (16:55 +0200)
* CUDA 11.0 fixes

* Cleaner CUDA/host flags separation

Also renamed GGML_ASSUME into GGML_CUDA_ASSUME

Makefile
ggml-cuda.cu

index a1438b80d842ad377d11ad5bf075db369f9c969a..73a9fb17a3ebedf3c7d759b112870cdccef67a47 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -95,16 +95,19 @@ CXXV := $(shell $(CXX) --version | head -n 1)
 #
 
 # keep standard at C11 and C++11
+MK_CPPFLAGS = -I. -Icommon
+MK_CFLAGS   = -std=c11   -fPIC
+MK_CXXFLAGS = -std=c++11 -fPIC
+
 # -Ofast tends to produce faster code, but may not be available for some compilers.
 ifdef LLAMA_FAST
-OPT = -Ofast
+MK_CFLAGS        += -Ofast
+MK_HOST_CXXFLAGS += -Ofast
+MK_CUDA_CXXFLAGS += -O3
 else
-OPT = -O3
+MK_CFLAGS        += -O3
+MK_CXXFLAGS      += -O3
 endif
-MK_CPPFLAGS = -I. -Icommon
-MK_CFLAGS   = $(OPT) -std=c11   -fPIC
-MK_CXXFLAGS = $(OPT) -std=c++11 -fPIC
-MK_LDFLAGS  =
 
 # clock_gettime came in POSIX.1b (1993)
 # CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
@@ -232,7 +235,7 @@ ifndef RISCV
 ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
        # Use all CPU extensions that are available:
        MK_CFLAGS   += -march=native -mtune=native
-       MK_CXXFLAGS += -march=native -mtune=native
+       MK_HOST_CXXFLAGS += -march=native -mtune=native
 
        # Usage AVX-only
        #MK_CFLAGS   += -mfma -mf16c -mavx
@@ -372,7 +375,7 @@ ifdef LLAMA_CUDA_CCBIN
        NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
 endif
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
-       $(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) -Wno-pedantic -c $< -o $@
+       $(NVCC) $(NVCCFLAGS) -Wno-pedantic -c $< -o $@
 endif # LLAMA_CUBLAS
 
 ifdef LLAMA_CLBLAST
@@ -440,23 +443,30 @@ k_quants.o: k_quants.c k_quants.h
 endif # LLAMA_NO_K_QUANTS
 
 # combine build flags with cmdline overrides
-override CFLAGS   := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CFLAGS) $(CFLAGS)
-override CXXFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CXXFLAGS) $(CXXFLAGS)
-override LDFLAGS  := $(MK_LDFLAGS) $(LDFLAGS)
+override CFLAGS        := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CFLAGS) $(CFLAGS)
+override CXXFLAGS      := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CXXFLAGS) $(CXXFLAGS)
+override CUDA_CXXFLAGS := $(MK_CUDA_CXXFLAGS) $(CUDA_CXXFLAGS)
+override HOST_CXXFLAGS := $(MK_HOST_CXXFLAGS) $(HOST_CXXFLAGS)
+override LDFLAGS       := $(MK_LDFLAGS) $(LDFLAGS)
+
+# save CXXFLAGS before we add host-only options
+NVCCFLAGS := $(NVCCFLAGS) $(CXXFLAGS) $(CUDA_CXXFLAGS) -Wno-pedantic -Xcompiler "$(HOST_CXXFLAGS)"
+override CXXFLAGS += $(HOST_CXXFLAGS)
 
 #
 # Print build information
 #
 
 $(info I llama.cpp build info: )
-$(info I UNAME_S:  $(UNAME_S))
-$(info I UNAME_P:  $(UNAME_P))
-$(info I UNAME_M:  $(UNAME_M))
-$(info I CFLAGS:   $(CFLAGS))
-$(info I CXXFLAGS: $(CXXFLAGS))
-$(info I LDFLAGS:  $(LDFLAGS))
-$(info I CC:       $(CCV))
-$(info I CXX:      $(CXXV))
+$(info I UNAME_S:   $(UNAME_S))
+$(info I UNAME_P:   $(UNAME_P))
+$(info I UNAME_M:   $(UNAME_M))
+$(info I CFLAGS:    $(CFLAGS))
+$(info I CXXFLAGS:  $(CXXFLAGS))
+$(info I NVCCFLAGS: $(NVCCFLAGS))
+$(info I LDFLAGS:   $(LDFLAGS))
+$(info I CC:        $(CCV))
+$(info I CXX:       $(CXXV))
 $(info )
 
 #
index fe7332b2a3580cb6a6b8c5f1ff04e3c8cc8e0979..dbe53ceece38ab0e05dbc4b121ab26853eb7d67a 100644 (file)
@@ -61,7 +61,7 @@
 #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
 #define cudaStreamNonBlocking hipStreamNonBlocking
 #define cudaStreamSynchronize hipStreamSynchronize
-#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
+#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
 #else
@@ -190,6 +190,12 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     } while (0)
 #endif // CUDART_VERSION >= 11
 
+#if CUDART_VERSION >= 11100
+#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
+#else
+#define GGML_CUDA_ASSUME(x)
+#endif // CUDART_VERSION >= 11100
+
 #ifdef GGML_CUDA_F16
 typedef half dfloat; // dequantize float
 typedef half2 dfloat2;
@@ -2145,10 +2151,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI4_0;
     const int kqsx = k % QI4_0;
@@ -2239,10 +2245,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI4_1;
     const int kqsx = k % QI4_1;
@@ -2331,10 +2337,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI5_0;
     const int kqsx = k % QI5_0;
@@ -2445,10 +2451,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI5_1;
     const int kqsx = k % QI5_1;
@@ -2551,10 +2557,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI8_0;
     const int kqsx = k % QI8_0;
@@ -2642,10 +2648,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI2_K;
     const int kqsx = k % QI2_K;
@@ -2763,10 +2769,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI3_K;
     const int kqsx = k % QI3_K;
@@ -2981,10 +2987,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI4_K; // == 0 if QK_K == 256
     const int kqsx = k % QI4_K; // == k if QK_K == 256
@@ -3162,10 +3168,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI5_K; // == 0 if QK_K == 256
     const int kqsx = k % QI5_K; // == k if QK_K == 256
@@ -3291,10 +3297,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI6_K; // == 0 if QK_K == 256
     const int kqsx = k % QI6_K; // == k if QK_K == 256
@@ -6408,7 +6414,7 @@ static void ggml_cuda_op_mul_mat(
 
             // wait for main GPU data if necessary
             if (split && (id != g_main_device || is != 0)) {
-                CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0]));
+                CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0], 0));
             }
 
             for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
@@ -6530,7 +6536,7 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         for (int64_t id = 0; id < g_device_count; ++id) {
             for (int64_t is = 0; is < is_max; ++is) {
-                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is]));
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
             }
         }
     }