]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : supports running on CPU for GGML_USE_CUBLAS=ON build (#3946)
authorMeng Zhang <redacted>
Tue, 7 Nov 2023 06:49:08 +0000 (22:49 -0800)
committerGitHub <redacted>
Tue, 7 Nov 2023 06:49:08 +0000 (08:49 +0200)
* protyping the idea that supports running on CPU for a GGML_USE_CUBLAS=on build

* doc: add comments to ggml_cublas_loaded()

* fix defined(...)

ggml-cuda.cu
ggml-cuda.h
llama.cpp

index 2d9ffffbf7496677f75a5ad1a8e449803a4fdb77..f87f18802c8f870f5570780dd412df5ea7f7a1aa 100644 (file)
@@ -5790,6 +5790,11 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
+static bool g_cublas_loaded = false;
+
+bool ggml_cublas_loaded(void) {
+    return g_cublas_loaded;
+}
 
 void ggml_init_cublas() {
     static bool initialized = false;
@@ -5803,7 +5808,12 @@ void ggml_init_cublas() {
         CUDA_CHECK(cudaDeviceSynchronize());
 #endif
 
-        CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
+        if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
+            initialized = true;
+            g_cublas_loaded = false;
+            return;
+        }
+
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
 #if defined(GGML_CUDA_FORCE_MMQ)
@@ -5851,6 +5861,7 @@ void ggml_init_cublas() {
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
 
         initialized = true;
+        g_cublas_loaded = true;
     }
 }
 
@@ -7158,6 +7169,8 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    if (!g_cublas_loaded) return false;
+
     const int64_t ne10 = src1->ne[0];
 
     const int64_t ne0 = dst->ne[0];
@@ -7843,6 +7856,8 @@ void ggml_cuda_free_scratch() {
 }
 
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+    if (!g_cublas_loaded) return false;
+
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
         || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
index 57adc9cf34bc5bc4fae4d576b4c2b1572364b8ff..528e66c33a20738ce185744ff5780203c869e4ad 100644 (file)
@@ -17,7 +17,12 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES       16
 
+// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
 GGML_API void   ggml_init_cublas(void);
+
+// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
+GGML_API bool   ggml_cublas_loaded(void);
+
 GGML_API void * ggml_cuda_host_malloc(size_t size);
 GGML_API void   ggml_cuda_host_free(void * ptr);
 
index e165390005c8501dd509d62f8754831b8ad0051a..d220ff3e9b130c3cf48fd49b26966fc3aa753e04 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -596,19 +596,37 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
 // llama helpers
 //
 
+inline void * llama_host_malloc(size_t n) {
 #ifdef GGML_USE_CUBLAS
-#   define llama_host_malloc(n)  ggml_cuda_host_malloc(n)
-#   define llama_host_free(data) ggml_cuda_host_free(data)
+    if (ggml_cublas_loaded()) {
+        return ggml_cuda_host_malloc(n);
+    } else {
+        return malloc(n);
+    }
 #elif GGML_USE_METAL
-#   define llama_host_malloc(n)  ggml_metal_host_malloc(n)
-#   define llama_host_free(data) ggml_metal_host_free(data)
+    return ggml_metal_host_malloc(n);
 #elif GGML_USE_CPU_HBM
-#   define llama_host_malloc(n)  hbw_malloc(n)
-#   define llama_host_free(data) if (data != NULL) hbw_free(data)
+    return hbw_malloc(n);
 #else
-#   define llama_host_malloc(n)  malloc(n)
-#   define llama_host_free(data) free(data)
+    return malloc(n);
 #endif
+}
+
+inline void llama_host_free(void * ptr) {
+#ifdef GGML_USE_CUBLAS
+    if (ggml_cublas_loaded()) {
+        return ggml_cuda_host_free(ptr);
+    } else {
+        return free(ptr);
+    }
+#elif GGML_USE_METAL
+    return ggml_metal_host_free(ptr);
+#elif GGML_USE_CPU_HBM
+    return hbw_free(ptr);
+#else
+    return free(ptr);
+#endif
+}
 
 #if defined(_WIN32)
 static std::string llama_format_win_err(DWORD err) {
@@ -1200,9 +1218,11 @@ struct llama_kv_cache {
         }
 
 #ifdef GGML_USE_CUBLAS
-        ggml_cuda_free_data(k);
-        ggml_cuda_free_data(v);
-#endif // GGML_USE_CUBLAS
+        if (ggml_cublas_loaded()) {
+            ggml_cuda_free_data(k);
+            ggml_cuda_free_data(v);
+        }
+#endif
     }
 };
 
@@ -1302,11 +1322,15 @@ struct llama_model {
         }
 
 #ifdef GGML_USE_CUBLAS
-        for (size_t i = 0; i < tensors_by_name.size(); ++i) {
-            ggml_cuda_free_data(tensors_by_name[i].second);
+        if (ggml_cublas_loaded()) {
+            for (size_t i = 0; i < tensors_by_name.size(); ++i) {
+                ggml_cuda_free_data(tensors_by_name[i].second);
+            }
+            ggml_cuda_free_scratch();
         }
-        ggml_cuda_free_scratch();
-#elif defined(GGML_USE_CLBLAST)
+#endif
+
+#if defined(GGML_USE_CLBLAST)
         for (size_t i = 0; i < tensors_by_name.size(); ++i) {
             ggml_cl_free_data(tensors_by_name[i].second);
         }
@@ -1418,23 +1442,26 @@ static bool llama_kv_cache_init(
     ggml_set_name(cache.v, "cache_v");
 
     (void) n_gpu_layers;
+
 #ifdef GGML_USE_CUBLAS
-    size_t vram_kv_cache = 0;
+    if (ggml_cublas_loaded()) {
+        size_t vram_kv_cache = 0;
 
-    if (n_gpu_layers > (int)n_layer + 1) {
-        ggml_cuda_assign_buffers_no_scratch(cache.v);
-        LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
-        vram_kv_cache += ggml_nbytes(cache.v);
-    }
-    if (n_gpu_layers > (int)n_layer + 2) {
-        ggml_cuda_assign_buffers_no_scratch(cache.k);
-        LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
-        vram_kv_cache += ggml_nbytes(cache.k);
-    }
-    if (vram_kv_cache > 0) {
-        LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
+        if (n_gpu_layers > (int)n_layer + 1) {
+            ggml_cuda_assign_buffers_no_scratch(cache.v);
+            LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
+            vram_kv_cache += ggml_nbytes(cache.v);
+        }
+        if (n_gpu_layers > (int)n_layer + 2) {
+            ggml_cuda_assign_buffers_no_scratch(cache.k);
+            LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
+            vram_kv_cache += ggml_nbytes(cache.k);
+        }
+        if (vram_kv_cache > 0) {
+            LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
+        }
     }
-#endif // GGML_USE_CUBLAS
+#endif
 
     return true;
 }
@@ -2521,18 +2548,22 @@ static void llm_load_tensors(
     }
 
     (void) main_gpu;
+
+    enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU;
+    enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU;
+
 #ifdef GGML_USE_CUBLAS
-    LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
-    ggml_cuda_set_main_device(main_gpu);
-#define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_GPU
-#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
+    if (ggml_cublas_loaded()) {
+        LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
+        ggml_cuda_set_main_device(main_gpu);
+
+        llama_backend_offload = GGML_BACKEND_GPU;
+        llama_backend_offload_split = GGML_BACKEND_GPU_SPLIT;
+    }
 #elif defined(GGML_USE_CLBLAST)
-    LLAMA_LOG_INFO("%s: using OpenCL for GPU acceleration\n", __func__);
-#define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_GPU
-#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU
-#else
-#define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_CPU
-#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU
+        LLAMA_LOG_INFO("%s: using OpenCL for GPU acceleration\n", __func__);
+        llama_backend_offload = GGML_BACKEND_GPU;
+        llama_backend_offload_split = GGML_BACKEND_GPU;
 #endif
 
     // prepare memory for the weights
@@ -2559,12 +2590,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;
@@ -2588,8 +2619,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];
 
@@ -2625,12 +2656,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;
@@ -2654,8 +2685,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];
 
@@ -2695,12 +2726,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;
@@ -2726,8 +2757,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];
 
@@ -2772,12 +2803,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;
@@ -2803,8 +2834,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];
 
@@ -2849,12 +2880,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;
@@ -2877,8 +2908,8 @@ static void llm_load_tensors(
                     const int i_gpu_start = n_layer - n_gpu_layers;
                     model.layers.resize(n_layer);
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload;
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split;
                         auto & layer = model.layers[i];
                         layer.attn_norm     = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, backend);
                         layer.attn_norm_b   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM,   "bias",   i), {n_embd}, backend);
@@ -2915,12 +2946,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;
@@ -2946,8 +2977,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];
 
@@ -2993,12 +3024,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;
@@ -3022,8 +3053,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];