]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : remove LLAMA_MAX_DEVICES and LLAMA_SUPPORTS_GPU_OFFLOAD (#5240)
authorGeorgi Gerganov <redacted>
Wed, 31 Jan 2024 15:30:17 +0000 (17:30 +0200)
committerGitHub <redacted>
Wed, 31 Jan 2024 15:30:17 +0000 (17:30 +0200)
* llama : remove LLAMA_MAX_DEVICES from llama.h

ggml-ci

* Update llama.cpp

Co-authored-by: slaren <redacted>
* server : remove LLAMA_MAX_DEVICES

ggml-ci

* llama : remove LLAMA_SUPPORTS_GPU_OFFLOAD

ggml-ci

* train : remove LLAMA_SUPPORTS_GPU_OFFLOAD

* readme : add deprecation notice

* readme : change deprecation notice to "remove" and fix url

* llama : remove gpu includes from llama.h

ggml-ci

---------

Co-authored-by: slaren <redacted>
README.md
common/common.cpp
common/common.h
common/train.cpp
examples/batched-bench/batched-bench.cpp
examples/llama-bench/llama-bench.cpp
examples/server/server.cpp
llama.cpp
llama.h

index 7746cb5100ebb1816edec0d79c704a7a98843e46..e6ed1d4294833402bab44ed993cda7ba89bf227d 100644 (file)
--- a/README.md
+++ b/README.md
@@ -10,7 +10,8 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 
 ### Hot topics
 
-- ⚠️ Incoming backends: https://github.com/ggerganov/llama.cpp/discussions/5138
+- Remove LLAMA_MAX_DEVICES and LLAMA_SUPPORTS_GPU_OFFLOAD: https://github.com/ggerganov/llama.cpp/pull/5240
+- Incoming backends: https://github.com/ggerganov/llama.cpp/discussions/5138
   - [SYCL backend](README-sycl.md) is ready (1/28/2024), support Linux/Windows in Intel GPUs (iGPU, Arc/Flex/Max series)
 - New SOTA quantized models, including pure 2-bits: https://huggingface.co/ikawrakow
 - Collecting Apple Silicon performance stats:
index 9d976c7c82ff4ad7b892fbd2256a3a940e5d7301..ce739b15c1586239fe8f1ad70ff2380fa27f0345 100644 (file)
@@ -583,20 +583,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_gpu_layers = std::stoi(argv[i]);
-#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
-            fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
-            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
-#endif
+            if (!llama_supports_gpu_offload()) {
+                fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
+                fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+            }
         } else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.n_gpu_layers_draft = std::stoi(argv[i]);
-#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
-            fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
-            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
-#endif
+            if (!llama_supports_gpu_offload()) {
+                fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
+                fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+            }
         } else if (arg == "--main-gpu" || arg == "-mg") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -637,11 +637,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             const std::regex regex{R"([,/]+)"};
             std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
             std::vector<std::string> split_arg{it, {}};
-            if (split_arg.size() >= LLAMA_MAX_DEVICES) {
+            if (split_arg.size() >= llama_max_devices()) {
                 invalid_param = true;
                 break;
             }
-            for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
+            for (size_t i = 0; i < llama_max_devices(); ++i) {
                 if (i < split_arg.size()) {
                     params.tensor_split[i] = std::stof(split_arg[i]);
                 } else {
@@ -989,30 +989,30 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
     printf("  --image IMAGE_FILE    path to an image file. use with multimodal models\n");
-    if (llama_mlock_supported()) {
+    if (llama_supports_mlock()) {
         printf("  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
     }
-    if (llama_mmap_supported()) {
+    if (llama_supports_mmap()) {
         printf("  --no-mmap             do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
     }
     printf("  --numa                attempt optimizations that help on some NUMA systems\n");
     printf("                        if run without this previously, it is recommended to drop the system page cache before using this\n");
     printf("                        see https://github.com/ggerganov/llama.cpp/issues/1437\n");
-#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
-    printf("  -ngl N, --n-gpu-layers N\n");
-    printf("                        number of layers to store in VRAM\n");
-    printf("  -ngld N, --n-gpu-layers-draft N\n");
-    printf("                        number of layers to store in VRAM for the draft model\n");
-    printf("  -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
-    printf("                        how to split the model across multiple GPUs, one of:\n");
-    printf("                          - none: use one GPU only\n");
-    printf("                          - layer (default): split layers and KV across GPUs\n");
-    printf("                          - row: split rows across GPUs\n");
-    printf("  -ts SPLIT, --tensor-split SPLIT\n");
-    printf("                        fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
-    printf("  -mg i, --main-gpu i   the GPU to use for the model (with split-mode = none),\n");
-    printf("                        or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
-#endif // LLAMA_SUPPORTS_GPU_OFFLOAD
+    if (llama_supports_gpu_offload()) {
+        printf("  -ngl N, --n-gpu-layers N\n");
+        printf("                        number of layers to store in VRAM\n");
+        printf("  -ngld N, --n-gpu-layers-draft N\n");
+        printf("                        number of layers to store in VRAM for the draft model\n");
+        printf("  -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
+        printf("                        how to split the model across multiple GPUs, one of:\n");
+        printf("                          - none: use one GPU only\n");
+        printf("                          - layer (default): split layers and KV across GPUs\n");
+        printf("                          - row: split rows across GPUs\n");
+        printf("  -ts SPLIT, --tensor-split SPLIT\n");
+        printf("                        fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
+        printf("  -mg i, --main-gpu i   the GPU to use for the model (with split-mode = none),\n");
+        printf("                        or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
+    }
     printf("  --verbose-prompt      print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false");
     printf("  --no-display-prompt   don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false");
     printf("  -gan N, --grp-attn-n N\n");
@@ -1651,7 +1651,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
 
-    const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES);
+    const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
     dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector);
 
     fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
index 214a379b57d1ba7f8f5cf1de0cbecfcc00c7003f..24a99d7281089e284a52c4d135da42ec60e4e688 100644 (file)
@@ -43,40 +43,40 @@ extern char const *LLAMA_BUILD_TARGET;
 int32_t get_num_physical_cores();
 
 struct gpt_params {
-    uint32_t seed                           = -1;    // RNG seed
-
-    int32_t n_threads                       = get_num_physical_cores();
-    int32_t n_threads_draft                 = -1;
-    int32_t n_threads_batch                 = -1;    // number of threads to use for batch processing (-1 = use n_threads)
-    int32_t n_threads_batch_draft           = -1;
-    int32_t n_predict                       = -1;    // new tokens to predict
-    int32_t n_ctx                           = 512;   // context size
-    int32_t n_batch                         = 512;   // batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_keep                          = 0;     // number of tokens to keep from initial prompt
-    int32_t n_draft                         = 8;     // number of tokens to draft during speculative decoding
-    int32_t n_chunks                        = -1;    // max number of chunks to process (-1 = unlimited)
-    int32_t n_parallel                      = 1;     // number of parallel sequences to decode
-    int32_t n_sequences                     = 1;     // number of sequences to decode
-    float   p_accept                        = 0.5f;  // speculative decoding accept probability
-    float   p_split                         = 0.1f;  // speculative decoding split probability
-    int32_t n_gpu_layers                    = -1;    // number of layers to store in VRAM (-1 - use default)
-    int32_t n_gpu_layers_draft              = -1;    // number of layers to store in VRAM for the draft model (-1 - use default)
-    llama_split_mode split_mode             = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
-    int32_t main_gpu                        = 0;     // the GPU that is used for scratch and small tensors
-    float   tensor_split[LLAMA_MAX_DEVICES] = {0};   // how split tensors should be distributed across GPUs
-    int32_t n_beams                         = 0;     // if non-zero then use beam search of given width.
-    int32_t grp_attn_n                      = 1;     // group-attention factor
-    int32_t grp_attn_w                      = 512;   // group-attention width
-    int32_t n_print                         = -1;    // print token count every n tokens (-1 = disabled)
-    float   rope_freq_base                  = 0.0f;  // RoPE base frequency
-    float   rope_freq_scale                 = 0.0f;  // RoPE frequency scaling factor
-    float   yarn_ext_factor                 = -1.0f; // YaRN extrapolation mix factor
-    float   yarn_attn_factor                = 1.0f;  // YaRN magnitude scaling factor
-    float   yarn_beta_fast                  = 32.0f; // YaRN low correction dim
-    float   yarn_beta_slow                  = 1.0f;  // YaRN high correction dim
-    int32_t yarn_orig_ctx                   = 0;     // YaRN original context length
-    int8_t  rope_scaling_type               = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
-                                                                              //       pinging @cebtenzzre
+    uint32_t seed                 = -1;    // RNG seed
+
+    int32_t n_threads             = get_num_physical_cores();
+    int32_t n_threads_draft       = -1;
+    int32_t n_threads_batch       = -1;    // number of threads to use for batch processing (-1 = use n_threads)
+    int32_t n_threads_batch_draft = -1;
+    int32_t n_predict             = -1;    // new tokens to predict
+    int32_t n_ctx                 = 512;   // context size
+    int32_t n_batch               = 512;   // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_keep                = 0;     // number of tokens to keep from initial prompt
+    int32_t n_draft               = 8;     // number of tokens to draft during speculative decoding
+    int32_t n_chunks              = -1;    // max number of chunks to process (-1 = unlimited)
+    int32_t n_parallel            = 1;     // number of parallel sequences to decode
+    int32_t n_sequences           = 1;     // number of sequences to decode
+    float   p_accept              = 0.5f;  // speculative decoding accept probability
+    float   p_split               = 0.1f;  // speculative decoding split probability
+    int32_t n_gpu_layers          = -1;    // number of layers to store in VRAM (-1 - use default)
+    int32_t n_gpu_layers_draft    = -1;    // number of layers to store in VRAM for the draft model (-1 - use default)
+    llama_split_mode split_mode   = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
+    int32_t main_gpu              = 0;     // the GPU that is used for scratch and small tensors
+    float   tensor_split[128]     = {0};   // how split tensors should be distributed across GPUs
+    int32_t n_beams               = 0;     // if non-zero then use beam search of given width.
+    int32_t grp_attn_n            = 1;     // group-attention factor
+    int32_t grp_attn_w            = 512;   // group-attention width
+    int32_t n_print               = -1;    // print token count every n tokens (-1 = disabled)
+    float   rope_freq_base        = 0.0f;  // RoPE base frequency
+    float   rope_freq_scale       = 0.0f;  // RoPE frequency scaling factor
+    float   yarn_ext_factor       = -1.0f; // YaRN extrapolation mix factor
+    float   yarn_attn_factor      = 1.0f;  // YaRN magnitude scaling factor
+    float   yarn_beta_fast        = 32.0f; // YaRN low correction dim
+    float   yarn_beta_slow        = 1.0f;  // YaRN high correction dim
+    int32_t yarn_orig_ctx         = 0;     // YaRN original context length
+    int8_t  rope_scaling_type     = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
+                                                                    //       pinging @cebtenzzre
 
     // // sampling parameters
     struct llama_sampling_params sparams;
index e6f2f7a2fbbfd3086da59de76381b1d81aeb89e7..e4c3d5df618187228920e8e5db4f4047831932b2 100644 (file)
@@ -1363,12 +1363,12 @@ bool consume_common_train_arg(
                 *invalid_param = true;
                 return true;
             }
-#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
-            params->n_gpu_layers = std::stoi(argv[i]);
-#else
-            fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
-            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
-#endif
+            if (llama_supports_gpu_offload()) {
+                params->n_gpu_layers = std::stoi(argv[i]);
+            } else {
+                fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
+                fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+            }
     } else if (arg == "-h" || arg == "--help") {
         params->print_usage = true;
         return true;
index 7924db267401c453c7bf8687a506835abaaa32ad..b52d684578ceb52a0b0c0b9733b001331abb8d31 100644 (file)
@@ -88,7 +88,7 @@ int main(int argc, char ** argv) {
 
     llama_model_params model_params = llama_model_default_params();
 
-    const std::vector<float> t_split (LLAMA_MAX_DEVICES, 0.0f);
+    const std::vector<float> t_split(llama_max_devices(), 0.0f);
 
     model_params.n_gpu_layers = n_gpu_layers;
     model_params.tensor_split = t_split.data();
index 542cc7bb8ea0de5a6a438d7856669e0ac4f89b2c..c5a6f744e177f813bf280605f1db0d9220123bf2 100644 (file)
@@ -160,7 +160,7 @@ struct cmd_params {
     std::vector<int> main_gpu;
     std::vector<bool> no_kv_offload;
     std::vector<bool> mul_mat_q;
-    std::vector<std::array<float, LLAMA_MAX_DEVICES>> tensor_split;
+    std::vector<std::vector<float>> tensor_split;
     int reps;
     bool verbose;
     output_formats output_format;
@@ -179,7 +179,7 @@ static const cmd_params cmd_params_defaults = {
     /* main_gpu      */ {0},
     /* no_kv_offload */ {false},
     /* mul_mat_q     */ {true},
-    /* tensor_split  */ {{}},
+    /* tensor_split  */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* reps          */ 5,
     /* verbose       */ false,
     /* output_format */ MARKDOWN
@@ -380,10 +380,10 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 const std::regex regex{R"([;/]+)"};
                 std::sregex_token_iterator it{ts.begin(), ts.end(), regex, -1};
                 std::vector<std::string> split_arg{it, {}};
-                GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
+                GGML_ASSERT(split_arg.size() <= llama_max_devices());
 
-                std::array<float, LLAMA_MAX_DEVICES> tensor_split;
-                for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
+                std::vector<float> tensor_split(llama_max_devices());
+                for (size_t i = 0; i < llama_max_devices(); ++i) {
                     if (i < split_arg.size()) {
                         tensor_split[i] = std::stof(split_arg[i]);
                     } else {
@@ -459,7 +459,7 @@ struct cmd_params_instance {
     int main_gpu;
     bool no_kv_offload;
     bool mul_mat_q;
-    std::array<float, LLAMA_MAX_DEVICES> tensor_split;
+    std::vector<float> tensor_split;
 
     llama_model_params to_llama_mparams() const {
         llama_model_params mparams = llama_model_default_params();
@@ -582,7 +582,7 @@ struct test {
     int main_gpu;
     bool no_kv_offload;
     bool mul_mat_q;
-    std::array<float, LLAMA_MAX_DEVICES> tensor_split;
+    std::vector<float> tensor_split;
     int n_prompt;
     int n_gen;
     std::string test_time;
@@ -704,7 +704,7 @@ struct test {
     std::vector<std::string> get_values() const {
         std::string tensor_split_str;
         int max_nonzero = 0;
-        for (int i = 0; i < LLAMA_MAX_DEVICES; i++) {
+        for (size_t i = 0; i < llama_max_devices(); i++) {
             if (tensor_split[i] > 0) {
                 max_nonzero = i;
             }
index 21bdce8edb7807f6999d2ef456c944127ae46032..ea77125eac99d2aaa00a1bc51a3b974948084c16 100644 (file)
@@ -1789,28 +1789,28 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch);
     printf("  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                            not recommended: doubles context memory required and no measurable increase in quality\n");
-    if (llama_mlock_supported())
+    if (llama_supports_mlock())
     {
         printf("  --mlock                   force system to keep model in RAM rather than swapping or compressing\n");
     }
-    if (llama_mmap_supported())
+    if (llama_supports_mmap())
     {
         printf("  --no-mmap                 do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
     }
     printf("  --numa                    attempt optimizations that help on some NUMA systems\n");
-#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
-    printf("  -ngl N, --n-gpu-layers N\n");
-    printf("                            number of layers to store in VRAM\n");
-    printf("  -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
-    printf("                            how to split the model across multiple GPUs, one of:\n");
-    printf("                              - none: use one GPU only\n");
-    printf("                              - layer (default): split layers and KV across GPUs\n");
-    printf("                              - row: split rows across GPUs\n");
-    printf("  -ts SPLIT --tensor-split SPLIT\n");
-    printf("                            fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
-    printf("  -mg i, --main-gpu i       the GPU to use for the model (with split-mode = none),\n");
-    printf("                            or for intermediate results and KV (with split-mode = row)\n");
-#endif
+    if (llama_supports_gpu_offload()) {
+        printf("  -ngl N, --n-gpu-layers N\n");
+        printf("                            number of layers to store in VRAM\n");
+        printf("  -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
+        printf("                            how to split the model across multiple GPUs, one of:\n");
+        printf("                              - none: use one GPU only\n");
+        printf("                              - layer (default): split layers and KV across GPUs\n");
+        printf("                              - row: split rows across GPUs\n");
+        printf("  -ts SPLIT --tensor-split SPLIT\n");
+        printf("                            fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
+        printf("  -mg i, --main-gpu i       the GPU to use for the model (with split-mode = none),\n");
+        printf("                            or for intermediate results and KV (with split-mode = row)\n");
+    }
     printf("  -m FNAME, --model FNAME\n");
     printf("                            model path (default: %s)\n", params.model.c_str());
     printf("  -a ALIAS, --alias ALIAS\n");
@@ -2066,13 +2066,13 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 invalid_param = true;
                 break;
             }
-#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
-            params.n_gpu_layers = std::stoi(argv[i]);
-#else
-            LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. "
+            if (llama_supports_gpu_offload()) {
+                params.n_gpu_layers = std::stoi(argv[i]);
+            } else {
+                LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. "
                         "See main README.md for information on enabling GPU BLAS support",
                         {{"n_gpu_layers", params.n_gpu_layers}});
-#endif
+            }
         }
         else if (arg == "--split-mode" || arg == "-sm")
         {
@@ -2115,9 +2115,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             const std::regex regex{R"([,/]+)"};
             std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
             std::vector<std::string> split_arg{it, {}};
-            GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
+            GGML_ASSERT(split_arg.size() <= llama_max_devices());
 
-            for (size_t i_device = 0; i_device < LLAMA_MAX_DEVICES; ++i_device)
+            for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device)
             {
                 if (i_device < split_arg.size())
                 {
index bb23689fac0e04aa711cfd7eee010e680ab255db..9b249ba9cde162bab0d0491ba4d87d475bd0276a 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -10090,18 +10090,45 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
     return result;
 }
 
-int32_t llama_max_devices(void) {
-    return LLAMA_MAX_DEVICES;
+size_t llama_max_devices(void) {
+#if defined(GGML_USE_METAL)
+    return 1;
+#elif defined(GGML_USE_CUBLAS)
+    return GGML_CUDA_MAX_DEVICES;
+#elif defined(GGML_USE_SYCL)
+    return GGML_SYCL_MAX_DEVICES;
+#else
+    return 1;
+#endif
 }
 
-bool llama_mmap_supported(void) {
+bool llama_supports_mmap(void) {
     return llama_mmap::SUPPORTED;
 }
 
-bool llama_mlock_supported(void) {
+bool llama_supports_mlock(void) {
     return llama_mlock::SUPPORTED;
 }
 
+bool llama_supports_gpu_offload(void) {
+#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
+    defined(GGML_USE_SYCL)   || defined(GGML_USE_KOMPUTE)
+    // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
+    return true;
+#else
+    return false;
+#endif
+}
+
+// deprecated:
+bool llama_mmap_supported(void) {
+    return llama_supports_mmap();
+}
+
+bool llama_mlock_supported(void) {
+    return llama_supports_mlock();
+}
+
 void llama_backend_init(bool numa) {
     ggml_time_init();
 
@@ -10133,8 +10160,8 @@ int64_t llama_time_us(void) {
 }
 
 struct llama_model * llama_load_model_from_file(
-                             const char * path_model,
-              struct llama_model_params   params) {
+        const char * path_model,
+        struct llama_model_params   params) {
     ggml_time_init();
 
     llama_model * model = new llama_model;
diff --git a/llama.h b/llama.h
index 17d43d03901eae501c8a3bffe85ad5d9224288e9..9a60e9bfbb6d37c6d59e11eb72788014787ce12e 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -3,15 +3,7 @@
 
 #include "ggml.h"
 #include "ggml-backend.h"
-#ifdef GGML_USE_CUBLAS
-#include "ggml-cuda.h"
-#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
-#elif defined(GGML_USE_SYCL)
-#include "ggml-sycl.h"
-#define LLAMA_MAX_DEVICES GGML_SYCL_MAX_DEVICES
-#else
-#define LLAMA_MAX_DEVICES 1
-#endif // GGML_USE_CUBLAS
+
 #include <stddef.h>
 #include <stdint.h>
 #include <stdio.h>
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
 #define LLAMA_SESSION_VERSION 4
 
-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
-    defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE)
-// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
-#define LLAMA_SUPPORTS_GPU_OFFLOAD
-#endif
-
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -201,7 +187,7 @@ extern "C" {
         // LLAMA_SPLIT_LAYER: ignored
         int32_t main_gpu;
 
-        // proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES
+        // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
         const float * tensor_split;
 
         // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
@@ -338,9 +324,14 @@ extern "C" {
 
     LLAMA_API int64_t llama_time_us(void);
 
-    LLAMA_API int32_t  llama_max_devices(void);
-    LLAMA_API bool llama_mmap_supported (void);
-    LLAMA_API bool llama_mlock_supported(void);
+    LLAMA_API size_t llama_max_devices(void);
+
+    LLAMA_API bool llama_supports_mmap       (void);
+    LLAMA_API bool llama_supports_mlock      (void);
+    LLAMA_API bool llama_supports_gpu_offload(void);
+
+    LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead");
+    LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead");
 
     LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);