### Hot topics
-- ⚠️ Incoming backends: https://github.com/ggerganov/llama.cpp/discussions/5138
+- Remove LLAMA_MAX_DEVICES and LLAMA_SUPPORTS_GPU_OFFLOAD: https://github.com/ggerganov/llama.cpp/pull/5240
+- Incoming backends: https://github.com/ggerganov/llama.cpp/discussions/5138
- [SYCL backend](README-sycl.md) is ready (1/28/2024), support Linux/Windows in Intel GPUs (iGPU, Arc/Flex/Max series)
- New SOTA quantized models, including pure 2-bits: https://huggingface.co/ikawrakow
- Collecting Apple Silicon performance stats:
break;
}
params.n_gpu_layers = std::stoi(argv[i]);
-#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
- fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
- fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
-#endif
+ if (!llama_supports_gpu_offload()) {
+ fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
+ fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+ }
} else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_gpu_layers_draft = std::stoi(argv[i]);
-#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
- fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
- fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
-#endif
+ if (!llama_supports_gpu_offload()) {
+ fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
+ fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+ }
} else if (arg == "--main-gpu" || arg == "-mg") {
if (++i >= argc) {
invalid_param = true;
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
- if (split_arg.size() >= LLAMA_MAX_DEVICES) {
+ if (split_arg.size() >= llama_max_devices()) {
invalid_param = true;
break;
}
- for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
+ for (size_t i = 0; i < llama_max_devices(); ++i) {
if (i < split_arg.size()) {
params.tensor_split[i] = std::stof(split_arg[i]);
} else {
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n");
- if (llama_mlock_supported()) {
+ if (llama_supports_mlock()) {
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
}
- if (llama_mmap_supported()) {
+ if (llama_supports_mmap()) {
printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
}
printf(" --numa attempt optimizations that help on some NUMA systems\n");
printf(" if run without this previously, it is recommended to drop the system page cache before using this\n");
printf(" see https://github.com/ggerganov/llama.cpp/issues/1437\n");
-#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
- printf(" -ngl N, --n-gpu-layers N\n");
- printf(" number of layers to store in VRAM\n");
- printf(" -ngld N, --n-gpu-layers-draft N\n");
- printf(" number of layers to store in VRAM for the draft model\n");
- printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
- printf(" how to split the model across multiple GPUs, one of:\n");
- printf(" - none: use one GPU only\n");
- printf(" - layer (default): split layers and KV across GPUs\n");
- printf(" - row: split rows across GPUs\n");
- printf(" -ts SPLIT, --tensor-split SPLIT\n");
- printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
- printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
- printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
-#endif // LLAMA_SUPPORTS_GPU_OFFLOAD
+ if (llama_supports_gpu_offload()) {
+ printf(" -ngl N, --n-gpu-layers N\n");
+ printf(" number of layers to store in VRAM\n");
+ printf(" -ngld N, --n-gpu-layers-draft N\n");
+ printf(" number of layers to store in VRAM for the draft model\n");
+ printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
+ printf(" how to split the model across multiple GPUs, one of:\n");
+ printf(" - none: use one GPU only\n");
+ printf(" - layer (default): split layers and KV across GPUs\n");
+ printf(" - row: split rows across GPUs\n");
+ printf(" -ts SPLIT, --tensor-split SPLIT\n");
+ printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
+ printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
+ printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
+ }
printf(" --verbose-prompt print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false");
printf(" --no-display-prompt don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false");
printf(" -gan N, --grp-attn-n N\n");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
- const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES);
+ const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector);
fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
int32_t get_num_physical_cores();
struct gpt_params {
- uint32_t seed = -1; // RNG seed
-
- int32_t n_threads = get_num_physical_cores();
- int32_t n_threads_draft = -1;
- int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
- int32_t n_threads_batch_draft = -1;
- int32_t n_predict = -1; // new tokens to predict
- int32_t n_ctx = 512; // context size
- int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
- int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_draft = 8; // number of tokens to draft during speculative decoding
- int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
- int32_t n_parallel = 1; // number of parallel sequences to decode
- int32_t n_sequences = 1; // number of sequences to decode
- float p_accept = 0.5f; // speculative decoding accept probability
- float p_split = 0.1f; // speculative decoding split probability
- int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
- int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
- llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
- int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
- float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
- int32_t n_beams = 0; // if non-zero then use beam search of given width.
- int32_t grp_attn_n = 1; // group-attention factor
- int32_t grp_attn_w = 512; // group-attention width
- int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
- float rope_freq_base = 0.0f; // RoPE base frequency
- float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
- float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
- float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
- float yarn_beta_fast = 32.0f; // YaRN low correction dim
- float yarn_beta_slow = 1.0f; // YaRN high correction dim
- int32_t yarn_orig_ctx = 0; // YaRN original context length
- int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
- // pinging @cebtenzzre
+ uint32_t seed = -1; // RNG seed
+
+ int32_t n_threads = get_num_physical_cores();
+ int32_t n_threads_draft = -1;
+ int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
+ int32_t n_threads_batch_draft = -1;
+ int32_t n_predict = -1; // new tokens to predict
+ int32_t n_ctx = 512; // context size
+ int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_draft = 8; // number of tokens to draft during speculative decoding
+ int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
+ int32_t n_parallel = 1; // number of parallel sequences to decode
+ int32_t n_sequences = 1; // number of sequences to decode
+ float p_accept = 0.5f; // speculative decoding accept probability
+ float p_split = 0.1f; // speculative decoding split probability
+ int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
+ int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+ llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
+ int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
+ float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+ int32_t n_beams = 0; // if non-zero then use beam search of given width.
+ int32_t grp_attn_n = 1; // group-attention factor
+ int32_t grp_attn_w = 512; // group-attention width
+ int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
+ float rope_freq_base = 0.0f; // RoPE base frequency
+ float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
+ float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
+ float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
+ float yarn_beta_fast = 32.0f; // YaRN low correction dim
+ float yarn_beta_slow = 1.0f; // YaRN high correction dim
+ int32_t yarn_orig_ctx = 0; // YaRN original context length
+ int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
+ // pinging @cebtenzzre
// // sampling parameters
struct llama_sampling_params sparams;
*invalid_param = true;
return true;
}
-#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
- params->n_gpu_layers = std::stoi(argv[i]);
-#else
- fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
- fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
-#endif
+ if (llama_supports_gpu_offload()) {
+ params->n_gpu_layers = std::stoi(argv[i]);
+ } else {
+ fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
+ fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+ }
} else if (arg == "-h" || arg == "--help") {
params->print_usage = true;
return true;
llama_model_params model_params = llama_model_default_params();
- const std::vector<float> t_split (LLAMA_MAX_DEVICES, 0.0f);
+ const std::vector<float> t_split(llama_max_devices(), 0.0f);
model_params.n_gpu_layers = n_gpu_layers;
model_params.tensor_split = t_split.data();
std::vector<int> main_gpu;
std::vector<bool> no_kv_offload;
std::vector<bool> mul_mat_q;
- std::vector<std::array<float, LLAMA_MAX_DEVICES>> tensor_split;
+ std::vector<std::vector<float>> tensor_split;
int reps;
bool verbose;
output_formats output_format;
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* mul_mat_q */ {true},
- /* tensor_split */ {{}},
+ /* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* reps */ 5,
/* verbose */ false,
/* output_format */ MARKDOWN
const std::regex regex{R"([;/]+)"};
std::sregex_token_iterator it{ts.begin(), ts.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
- GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
+ GGML_ASSERT(split_arg.size() <= llama_max_devices());
- std::array<float, LLAMA_MAX_DEVICES> tensor_split;
- for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
+ std::vector<float> tensor_split(llama_max_devices());
+ for (size_t i = 0; i < llama_max_devices(); ++i) {
if (i < split_arg.size()) {
tensor_split[i] = std::stof(split_arg[i]);
} else {
int main_gpu;
bool no_kv_offload;
bool mul_mat_q;
- std::array<float, LLAMA_MAX_DEVICES> tensor_split;
+ std::vector<float> tensor_split;
llama_model_params to_llama_mparams() const {
llama_model_params mparams = llama_model_default_params();
int main_gpu;
bool no_kv_offload;
bool mul_mat_q;
- std::array<float, LLAMA_MAX_DEVICES> tensor_split;
+ std::vector<float> tensor_split;
int n_prompt;
int n_gen;
std::string test_time;
std::vector<std::string> get_values() const {
std::string tensor_split_str;
int max_nonzero = 0;
- for (int i = 0; i < LLAMA_MAX_DEVICES; i++) {
+ for (size_t i = 0; i < llama_max_devices(); i++) {
if (tensor_split[i] > 0) {
max_nonzero = i;
}
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
- if (llama_mlock_supported())
+ if (llama_supports_mlock())
{
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
}
- if (llama_mmap_supported())
+ if (llama_supports_mmap())
{
printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
}
printf(" --numa attempt optimizations that help on some NUMA systems\n");
-#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
- printf(" -ngl N, --n-gpu-layers N\n");
- printf(" number of layers to store in VRAM\n");
- printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
- printf(" how to split the model across multiple GPUs, one of:\n");
- printf(" - none: use one GPU only\n");
- printf(" - layer (default): split layers and KV across GPUs\n");
- printf(" - row: split rows across GPUs\n");
- printf(" -ts SPLIT --tensor-split SPLIT\n");
- printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
- printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
- printf(" or for intermediate results and KV (with split-mode = row)\n");
-#endif
+ if (llama_supports_gpu_offload()) {
+ printf(" -ngl N, --n-gpu-layers N\n");
+ printf(" number of layers to store in VRAM\n");
+ printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
+ printf(" how to split the model across multiple GPUs, one of:\n");
+ printf(" - none: use one GPU only\n");
+ printf(" - layer (default): split layers and KV across GPUs\n");
+ printf(" - row: split rows across GPUs\n");
+ printf(" -ts SPLIT --tensor-split SPLIT\n");
+ printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
+ printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
+ printf(" or for intermediate results and KV (with split-mode = row)\n");
+ }
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
printf(" -a ALIAS, --alias ALIAS\n");
invalid_param = true;
break;
}
-#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
- params.n_gpu_layers = std::stoi(argv[i]);
-#else
- LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. "
+ if (llama_supports_gpu_offload()) {
+ params.n_gpu_layers = std::stoi(argv[i]);
+ } else {
+ LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. "
"See main README.md for information on enabling GPU BLAS support",
{{"n_gpu_layers", params.n_gpu_layers}});
-#endif
+ }
}
else if (arg == "--split-mode" || arg == "-sm")
{
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
- GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
+ GGML_ASSERT(split_arg.size() <= llama_max_devices());
- for (size_t i_device = 0; i_device < LLAMA_MAX_DEVICES; ++i_device)
+ for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device)
{
if (i_device < split_arg.size())
{
return result;
}
-int32_t llama_max_devices(void) {
- return LLAMA_MAX_DEVICES;
+size_t llama_max_devices(void) {
+#if defined(GGML_USE_METAL)
+ return 1;
+#elif defined(GGML_USE_CUBLAS)
+ return GGML_CUDA_MAX_DEVICES;
+#elif defined(GGML_USE_SYCL)
+ return GGML_SYCL_MAX_DEVICES;
+#else
+ return 1;
+#endif
}
-bool llama_mmap_supported(void) {
+bool llama_supports_mmap(void) {
return llama_mmap::SUPPORTED;
}
-bool llama_mlock_supported(void) {
+bool llama_supports_mlock(void) {
return llama_mlock::SUPPORTED;
}
+bool llama_supports_gpu_offload(void) {
+#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
+ defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE)
+ // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
+ return true;
+#else
+ return false;
+#endif
+}
+
+// deprecated:
+bool llama_mmap_supported(void) {
+ return llama_supports_mmap();
+}
+
+bool llama_mlock_supported(void) {
+ return llama_supports_mlock();
+}
+
void llama_backend_init(bool numa) {
ggml_time_init();
}
struct llama_model * llama_load_model_from_file(
- const char * path_model,
- struct llama_model_params params) {
+ const char * path_model,
+ struct llama_model_params params) {
ggml_time_init();
llama_model * model = new llama_model;
#include "ggml.h"
#include "ggml-backend.h"
-#ifdef GGML_USE_CUBLAS
-#include "ggml-cuda.h"
-#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
-#elif defined(GGML_USE_SYCL)
-#include "ggml-sycl.h"
-#define LLAMA_MAX_DEVICES GGML_SYCL_MAX_DEVICES
-#else
-#define LLAMA_MAX_DEVICES 1
-#endif // GGML_USE_CUBLAS
+
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 4
-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
- defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE)
-// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
-#define LLAMA_SUPPORTS_GPU_OFFLOAD
-#endif
-
#ifdef __cplusplus
extern "C" {
#endif
// LLAMA_SPLIT_LAYER: ignored
int32_t main_gpu;
- // proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES
+ // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
const float * tensor_split;
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
LLAMA_API int64_t llama_time_us(void);
- LLAMA_API int32_t llama_max_devices(void);
- LLAMA_API bool llama_mmap_supported (void);
- LLAMA_API bool llama_mlock_supported(void);
+ LLAMA_API size_t llama_max_devices(void);
+
+ LLAMA_API bool llama_supports_mmap (void);
+ LLAMA_API bool llama_supports_mlock (void);
+ LLAMA_API bool llama_supports_gpu_offload(void);
+
+ LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead");
+ LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead");
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);