if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency();
}
+ } else if (arg == "-tb" || arg == "--threads-batch") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_threads_batch = std::stoi(argv[i]);
+ if (params.n_threads_batch <= 0) {
+ params.n_threads_batch = std::thread::hardware_concurrency();
+ }
} else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) {
invalid_param = true;
params.mul_mat_q = false;
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
-#endif // GGML_USE_CUBLAS
- } else if (arg == "--low-vram" || arg == "-lv") {
-#ifdef GGML_USE_CUBLAS
- params.low_vram = true;
-#else
- fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--no-mmap") {
params.use_mmap = false;
printf(" (can be specified more than once for multiple prompts).\n");
printf(" --color colorise output to distinguish prompt and user input from generations\n");
printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
- printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
+ printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
+ printf(" -tb N, --threads-batch N\n");
+ printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" -p PROMPT, --prompt PROMPT\n");
printf(" prompt to start generation with (default: empty)\n");
printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
printf(" -f FNAME, --file FNAME\n");
printf(" prompt file to start generation.\n");
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
- printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
+ printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
printf(" -ts SPLIT --tensor-split SPLIT\n");
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
- printf(" -lv, --low-vram don't allocate VRAM scratch buffer\n");
#ifdef GGML_USE_CUBLAS
printf(" -nommq, --no-mul-mat-q\n");
printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
printf("\n");
}
+std::string get_system_info(const gpt_params & params) {
+ std::ostringstream os;
+
+ os << "system_info: n_threads = " << params.n_threads;
+ if (params.n_threads_batch != -1) {
+ os << " (n_threads_batch = " << params.n_threads_batch << ")";
+ }
+ os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
+
+ return os.str();
+}
+
std::string gpt_random_prompt(std::mt19937 & rng) {
const int r = rng() % 10;
switch (r) {
// Model utils
//
-struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
- auto lparams = llama_context_default_params();
+struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
+ auto mparams = llama_model_default_params();
- lparams.n_ctx = params.n_ctx;
- lparams.n_batch = params.n_batch;
if (params.n_gpu_layers != -1) {
- lparams.n_gpu_layers = params.n_gpu_layers;
+ mparams.n_gpu_layers = params.n_gpu_layers;
}
- lparams.main_gpu = params.main_gpu;
- lparams.tensor_split = params.tensor_split;
- lparams.low_vram = params.low_vram;
- lparams.mul_mat_q = params.mul_mat_q;
- lparams.seed = params.seed;
- lparams.f16_kv = params.memory_f16;
- lparams.use_mmap = params.use_mmap;
- lparams.use_mlock = params.use_mlock;
- lparams.logits_all = params.logits_all;
- lparams.embedding = params.embedding;
- lparams.rope_freq_base = params.rope_freq_base;
- lparams.rope_freq_scale = params.rope_freq_scale;
-
- return lparams;
+ mparams.main_gpu = params.main_gpu;
+ mparams.tensor_split = params.tensor_split;
+ mparams.use_mmap = params.use_mmap;
+ mparams.use_mlock = params.use_mlock;
+
+ return mparams;
+}
+
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
+ auto cparams = llama_context_default_params();
+
+ cparams.n_ctx = params.n_ctx;
+ cparams.n_batch = params.n_batch;
+ cparams.n_threads = params.n_threads;
+ cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+ cparams.mul_mat_q = params.mul_mat_q;
+ cparams.seed = params.seed;
+ cparams.f16_kv = params.memory_f16;
+ cparams.logits_all = params.logits_all;
+ cparams.embedding = params.embedding;
+ cparams.rope_freq_base = params.rope_freq_base;
+ cparams.rope_freq_scale = params.rope_freq_scale;
+
+ return cparams;
}
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
- auto lparams = llama_context_params_from_gpt_params(params);
+ auto mparams = llama_model_params_from_gpt_params(params);
- llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr);
}
- llama_context * lctx = llama_new_context_with_model(model, lparams);
+ auto cparams = llama_context_params_from_gpt_params(params);
+
+ llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
LOG("warming up the model with an empty run\n");
std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
- llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
+ llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_tokens_rm(lctx, -1, -1);
llama_reset_timings(lctx);
}
//
std::vector<llama_token> llama_tokenize(
- struct llama_context * ctx,
+ const struct llama_context * ctx,
+ const std::string & text,
+ bool add_bos) {
+ return llama_tokenize(llama_get_model(ctx), text, add_bos);
+}
+
+std::vector<llama_token> llama_tokenize(
+ const struct llama_model * model,
const std::string & text,
bool add_bos) {
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
- n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
+ n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
if (n_tokens < 0) {
result.resize(-n_tokens);
- int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
+ int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
- const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
+ const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
if (n_tokens < 0) {
result.resize(-n_tokens);
- int check = llama_token_to_piece(ctx, token, result.data(), result.size());
+ int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
std::vector<llama_token_data> & candidates,
int idx) {
const int n_ctx = llama_n_ctx(ctx);
- const int n_vocab = llama_n_vocab(ctx);
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
#endif // NDEBUG
fprintf(stream, "model_desc: %s\n", model_desc);
- fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(lctx));
+ fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx)));
#ifdef __OPTIMIZE__
fprintf(stream, "optimize: true\n");
fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
}
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
- fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false");
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat);
struct gpt_params {
uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores();
+ int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
- bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
+std::string get_system_info(const gpt_params & params);
+
std::string gpt_random_prompt(std::mt19937 & rng);
void process_escapes(std::string& input);
//
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
+struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
//
// tokenizes a string into a vector of tokens
// should work similar to Python's `tokenizer.encode`
std::vector<llama_token> llama_tokenize(
- struct llama_context * ctx,
+ const struct llama_context * ctx,
+ const std::string & text,
+ bool add_bos);
+
+std::vector<llama_token> llama_tokenize(
+ const struct llama_model * model,
const std::string & text,
bool add_bos);
out_tokens.resize(buf.size() + n_max_tokens_overhead);
int n_tokens = llama_tokenize(
- lctx,
+ llama_get_model(lctx),
buf.data(),
(int) buf.size(),
out_tokens.data(),
if (n_tokens < 0) {
out_tokens.resize(-n_tokens);
n_tokens = llama_tokenize(
- lctx,
+ llama_get_model(lctx),
buf.data(),
(int) buf.size(),
out_tokens.data(),
size_t found_max_sample_size = 0;
size_t max_token_text_size = 0;
- int n_vocab = llama_n_vocab(lctx);
+ int n_vocab = llama_n_vocab(llama_get_model(lctx));
for (llama_token token=0; token < n_vocab; ++token) {
max_token_text_size = std::max(
max_token_text_size,
// tokenize the sample
tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
- int n_tokens = llama_tokenize(lctx,
+ int n_tokens = llama_tokenize(llama_get_model(lctx),
buf_sample.data(),
(int) buf_sample.size(),
tok_sample.data(),
false);
if (n_tokens < 0) {
tok_sample.resize(-n_tokens);
- n_tokens = llama_tokenize(lctx,
+ n_tokens = llama_tokenize(llama_get_model(lctx),
buf_sample.data(),
(int) buf_sample.size(),
tok_sample.data(),
llama_backend_init(params.numa);
- llama_context_params ctx_params = llama_context_default_params();
+ // initialize the model
- ctx_params.seed = 1234;
- ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
- ctx_params.n_batch = std::max(n_len, n_parallel);
- // ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
+ llama_model_params model_params = llama_model_default_params();
- llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
+ // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
+ // tokenize the prompt
+
+ std::vector<llama_token> tokens_list;
+ tokens_list = ::llama_tokenize(model, params.prompt, true);
+ const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
+
+ // initialize the context
+
+ llama_context_params ctx_params = llama_context_default_params();
+
+ ctx_params.seed = 1234;
+ ctx_params.n_ctx = n_kv_req;
+ ctx_params.n_batch = std::max(n_len, n_parallel);
+ ctx_params.n_threads = params.n_threads;
+ ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
return 1;
}
- // tokenize the prompt
-
- std::vector<llama_token> tokens_list;
- tokens_list = ::llama_tokenize(ctx, params.prompt, true);
-
const int n_ctx = llama_n_ctx(ctx);
- const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
- if (llama_decode(ctx, batch, params.n_threads) != 0) {
+ if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
continue;
}
- auto n_vocab = llama_n_vocab(ctx);
+ auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
std::vector<llama_token_data> candidates;
n_cur += 1;
// evaluate the current batch with the transformer model
- if (llama_decode(ctx, batch, params.n_threads)) {
+ if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
int n_past = 0;
- if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads))
+ if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0)))
{
fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
return 1;
beam_search_callback_data callback_data{ctx, {}};
size_t const beam_width = static_cast<size_t>(params.n_beams);
int const n_predict = 256;
- llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);
+ llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict);
std::cout << "\n\n";
for (llama_token const token_id : callback_data.response) {
// print system information
{
fprintf(stderr, "\n");
- fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
- params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+ fprintf(stderr, "%s\n", get_system_info(params).c_str());
}
struct MyModel * ret = new MyModel();
ret->ctx = ctx;
MyModel * mymodel = (MyModel*)model;
llama_context * ctx = mymodel->ctx;
gpt_params params = mymodel->params;
- int n_emb = llama_n_embd(ctx);
+ int n_emb = llama_n_embd(llama_get_model(ctx));
int n_past = mymodel->n_past;
int n_batch = N; // params.n_batch;
n_eval = n_batch;
}
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
- if (llama_decode(ctx, batch, params.n_threads)) {
+ if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
- if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
// out of user input, sample next token
const float temp = params.temp;
- const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+ const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
llama_token id = 0;
{
auto logits = llama_get_logits(ctx);
- auto n_vocab = llama_n_vocab(ctx);
+ auto n_vocab = llama_n_vocab(llama_get_model(ctx));
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
auto mymodel = create_mymodel(argc, argv);
int N = 10;
int max_tgt_len = 500;
- int n_embd = llama_n_embd(mymodel->ctx);
+ int n_embd = llama_n_embd(llama_get_model(mymodel->ctx));
// add random float embd to test evaluation
float * data = new float[N*n_embd];
return 1;
}
- const int n_ctx_train = llama_n_ctx_train(ctx);
- if (params.n_ctx > n_ctx_train) {
+ const int n_ctx_train = llama_n_ctx_train(model);
+ const int n_ctx = llama_n_ctx(ctx);
+
+ if (n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
- __func__, n_ctx_train, params.n_ctx);
+ __func__, n_ctx_train, n_ctx);
}
// print system information
{
fprintf(stderr, "\n");
- fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
- params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+ fprintf(stderr, "%s\n", get_system_info(params).c_str());
}
int n_past = 0;
fprintf(stderr, "\n");
}
- if (embd_inp.size() > (size_t)params.n_ctx) {
+ if (embd_inp.size() > (size_t)n_ctx) {
fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
- __func__, embd_inp.size(), params.n_ctx);
+ __func__, embd_inp.size(), n_ctx);
return 1;
}
while (!embd_inp.empty()) {
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
- if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
}
- const int n_embd = llama_n_embd(ctx);
- const auto embeddings = llama_get_embeddings(ctx);
+ const int n_embd = llama_n_embd(model);
+ const auto * embeddings = llama_get_embeddings(ctx);
for (int i = 0; i < n_embd; i++) {
printf("%f ", embeddings[i]);
gguf_free(mctx);
}
- hparams.n_vocab = llama_model_n_vocab(input);
+ hparams.n_vocab = llama_n_vocab(input);
hparams.n_ctx = n_ctx;
// get tensors from llama_model (possibly mmapped)
printf("%s: seed: %u\n", __func__, params.common.seed);
srand(params.common.seed);
- struct llama_context_params llama_params = llama_context_default_params();
- llama_params.vocab_only = false;
+ struct llama_model_params llama_mparams = llama_model_default_params();
+ llama_mparams.vocab_only = false;
printf("%s: model base = '%s'\n", __func__, params.fn_model_base);
- struct llama_model * lmodel = llama_load_model_from_file(params.fn_model_base, llama_params);
- struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
+ struct llama_model * lmodel = llama_load_model_from_file(params.fn_model_base, llama_mparams);
+
+ struct llama_context_params llama_cparams = llama_context_default_params();
+ struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_cparams);
struct my_llama_model model;
init_model(lmodel, &model, params.fn_model_base, params.common.n_ctx);
std::vector<int> n_gpu_layers;
std::vector<int> main_gpu;
std::vector<bool> mul_mat_q;
- std::vector<bool> low_vram;
std::vector<std::array<float, LLAMA_MAX_DEVICES>> tensor_split;
int reps;
bool verbose;
/* n_gpu_layers */ {99},
/* main_gpu */ {0},
/* mul_mat_q */ {true},
- /* low_vram */ {false},
/* tensor_split */ {{}},
/* reps */ 5,
/* verbose */ false,
printf(" -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
printf(" --memory-f32 <0|1> (default: %s)\n", join(cmd_params_defaults.f32_kv, ",").c_str());
printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
- printf(" -ngl N, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
- printf(" -mg i, --main-gpu <n> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
- printf(" -lv, --low-vram <0|1> (default: %s)\n", join(cmd_params_defaults.low_vram, ",").c_str());
+ printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
+ printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -mmq, --mul-mat-q <0|1> (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str());
printf(" -ts, --tensor_split <ts0/ts1/..> \n");
printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
break;
}
params.main_gpu = split<int>(argv[i], split_delim);
- } else if (arg == "-lv" || arg == "--low-vram") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- auto p = split<bool>(argv[i], split_delim);
- params.low_vram.insert(params.low_vram.end(), p.begin(), p.end());
} else if (arg == "-mmq" || arg == "--mul-mat-q") {
if (++i >= argc) {
invalid_param = true;
if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
if (params.mul_mat_q.empty()) { params.mul_mat_q = cmd_params_defaults.mul_mat_q; }
- if (params.low_vram.empty()) { params.low_vram = cmd_params_defaults.low_vram; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; }
int n_gpu_layers;
int main_gpu;
bool mul_mat_q;
- bool low_vram;
std::array<float, LLAMA_MAX_DEVICES> tensor_split;
- llama_context_params to_llama_params() const {
- llama_context_params lparams = llama_context_default_params();
- lparams.n_ctx = n_prompt + n_gen;
- lparams.n_batch = n_batch;
- lparams.f16_kv = !f32_kv;
- lparams.n_gpu_layers = n_gpu_layers;
- lparams.main_gpu = main_gpu;
- lparams.mul_mat_q = mul_mat_q;
- lparams.low_vram = low_vram;
- lparams.tensor_split = tensor_split.data();
+ llama_model_params to_llama_mparams() const {
+ llama_model_params mparams = llama_model_default_params();
+
+ mparams.n_gpu_layers = n_gpu_layers;
+ mparams.main_gpu = main_gpu;
+ mparams.tensor_split = tensor_split.data();
+
+ return mparams;
+ }
+
+ bool equal_mparams(const cmd_params_instance & other) const {
+ return model == other.model &&
+ n_gpu_layers == other.n_gpu_layers &&
+ main_gpu == other.main_gpu &&
+ tensor_split == other.tensor_split;
+ }
+
+ llama_context_params to_llama_cparams() const {
+ llama_context_params cparams = llama_context_default_params();
- return lparams;
+ cparams.n_ctx = n_prompt + n_gen;
+ cparams.n_batch = n_batch;
+ cparams.f16_kv = !f32_kv;
+ cparams.mul_mat_q = mul_mat_q;
+
+ return cparams;
}
};
std::vector<cmd_params_instance> instances;
for (const auto & m : params.model)
- for (const auto & nb : params.n_batch)
- for (const auto & fk : params.f32_kv)
for (const auto & nl : params.n_gpu_layers)
for (const auto & mg : params.main_gpu)
- for (const auto & mmq : params.mul_mat_q)
- for (const auto & lv : params.low_vram)
for (const auto & ts : params.tensor_split)
+ for (const auto & nb : params.n_batch)
+ for (const auto & fk : params.f32_kv)
+ for (const auto & mmq : params.mul_mat_q)
for (const auto & nt : params.n_threads) {
cmd_params_instance instance = {
/* .model = */ m,
/* .n_gpu_layers = */ nl,
/* .main_gpu = */ mg,
/* .mul_mat_q = */ mmq,
- /* .low_vram = */ lv,
/* .tensor_split = */ ts,
};
instances.push_back(instance);
static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_params & params) {
std::vector<cmd_params_instance> instances;
+#if 1
+ // this ordering minimizes the number of times that each model needs to be reloaded
+ for (const auto & m : params.model)
+ for (const auto & nl : params.n_gpu_layers)
+ for (const auto & mg : params.main_gpu)
+ for (const auto & ts : params.tensor_split)
+ for (const auto & nb : params.n_batch)
+ for (const auto & fk : params.f32_kv)
+ for (const auto & mmq : params.mul_mat_q)
+ for (const auto & nt : params.n_threads) {
+ for (const auto & n_prompt : params.n_prompt) {
+ if (n_prompt == 0) {
+ continue;
+ }
+ cmd_params_instance instance = {
+ /* .model = */ m,
+ /* .n_prompt = */ n_prompt,
+ /* .n_gen = */ 0,
+ /* .n_batch = */ nb,
+ /* .f32_kv = */ fk,
+ /* .n_threads = */ nt,
+ /* .n_gpu_layers = */ nl,
+ /* .main_gpu = */ mg,
+ /* .mul_mat_q = */ mmq,
+ /* .tensor_split = */ ts,
+ };
+ instances.push_back(instance);
+ }
+
+ for (const auto & n_gen : params.n_gen) {
+ if (n_gen == 0) {
+ continue;
+ }
+ cmd_params_instance instance = {
+ /* .model = */ m,
+ /* .n_prompt = */ 0,
+ /* .n_gen = */ n_gen,
+ /* .n_batch = */ nb,
+ /* .f32_kv = */ fk,
+ /* .n_threads = */ nt,
+ /* .n_gpu_layers = */ nl,
+ /* .main_gpu = */ mg,
+ /* .mul_mat_q = */ mmq,
+ /* .tensor_split = */ ts,
+ };
+ instances.push_back(instance);
+ }
+ }
+#else
+ // this ordering separates the prompt and generation tests
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
continue;
auto instances_gen = get_cmd_params_instances_int(params, n_gen, 0);
instances.insert(instances.end(), instances_gen.begin(), instances_gen.end());
}
+#endif
return instances;
}
int n_gpu_layers;
int main_gpu;
bool mul_mat_q;
- bool low_vram;
std::array<float, LLAMA_MAX_DEVICES> tensor_split;
int n_prompt;
int n_gen;
n_gpu_layers = inst.n_gpu_layers;
main_gpu = inst.main_gpu;
mul_mat_q = inst.mul_mat_q;
- low_vram = inst.low_vram;
tensor_split = inst.tensor_split;
n_prompt = inst.n_prompt;
n_gen = inst.n_gen;
"cpu_info", "gpu_info",
"model_filename", "model_type", "model_size", "model_n_params",
"n_batch", "n_threads", "f16_kv",
- "n_gpu_layers", "main_gpu", "mul_mat_q", "low_vram", "tensor_split",
+ "n_gpu_layers", "main_gpu", "mul_mat_q", "tensor_split",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
"avg_ts", "stddev_ts"
return INT;
}
if (field == "cuda" || field == "opencl" || field == "metal" || field == "gpu_blas" || field == "blas" ||
- field == "f16_kv" || field == "mul_mat_q" || field == "low_vram") {
+ field == "f16_kv" || field == "mul_mat_q") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
cpu_info, gpu_info,
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv),
- std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), std::to_string(low_vram), tensor_split_str,
+ std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), tensor_split_str,
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
std::to_string(avg_ts()), std::to_string(stdev_ts())
if (params.mul_mat_q.size() > 1 || params.mul_mat_q != cmd_params_defaults.mul_mat_q) {
fields.push_back("mul_mat_q");
}
- if (params.low_vram.size() > 1 || params.low_vram != cmd_params_defaults.low_vram) {
- fields.push_back("low_vram");
- }
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.push_back("tensor_split");
}
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
int n_processed = 0;
+
+ llama_set_n_threads(ctx, n_threads, n_threads);
+
while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch);
- llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads);
+ llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
n_processed += n_tokens;
}
}
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_token token = llama_token_bos(ctx);
+
+ llama_set_n_threads(ctx, n_threads, n_threads);
+
for (int i = 0; i < n_gen; i++) {
- llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads);
+ llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
}
}
std::vector<cmd_params_instance> params_instances = get_cmd_params_instances(params);
+ llama_model * lmodel = nullptr;
+ const cmd_params_instance * prev_inst = nullptr;
+
for (const auto & inst : params_instances) {
- // TODO: keep the model between tests when possible
- llama_context_params lparams = inst.to_llama_params();
+ // keep the same model between tests when possible
+ if (!lmodel || !prev_inst || !inst.equal_mparams(*prev_inst)) {
+ if (lmodel) {
+ llama_free_model(lmodel);
+ }
- llama_model * lmodel = llama_load_model_from_file(inst.model.c_str(), lparams);
- if (lmodel == NULL) {
- fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str());
- return 1;
+ lmodel = llama_load_model_from_file(inst.model.c_str(), inst.to_llama_mparams());
+ if (lmodel == NULL) {
+ fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str());
+ return 1;
+ }
+ prev_inst = &inst;
}
- llama_context * ctx = llama_new_context_with_model(lmodel, lparams);
+ llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams());
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str());
llama_free_model(lmodel);
llama_print_timings(ctx);
llama_free(ctx);
- llama_free_model(lmodel);
}
+ llama_free_model(lmodel);
+
p->print_footer();
llama_backend_free();
### Number of Threads
-- `-t N, --threads N`: Set the number of threads to use during computation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance.
+- `-t N, --threads N`: Set the number of threads to use during generation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance.
+- `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. On some systems, it is beneficial to use a higher number of threads during batch processing than during generation. If not specified, the number of threads used for batch processing will be the same as the number of threads used for generation.
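As a rough sketch of how these two flags flow through the new API in this change (the helper name `init_with_threads` is illustrative, not part of the patch; model loading, CLI parsing, and error handling are omitted):

```cpp
#include "llama.h"

// Minimal sketch: -t sets n_threads, -tb sets n_threads_batch on the context.
// A loaded `model` is assumed.
static llama_context * init_with_threads(llama_model * model, int32_t n_threads, int32_t n_threads_batch) {
    llama_context_params cparams = llama_context_default_params();

    cparams.n_threads       = n_threads;
    // -tb falls back to -t when not specified (the patch uses -1 as "unset")
    cparams.n_threads_batch = n_threads_batch == -1 ? n_threads : n_threads_batch;

    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // the counts can also be adjusted later on a live context
    llama_set_n_threads(ctx, cparams.n_threads, cparams.n_threads_batch);

    return ctx;
}
```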
### Mlock
- `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
-- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS.
- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
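The GPU-related options above (`-ngl`, `-mg`, `-ts`) now land on `llama_model_params` rather than on the context parameters. A minimal sketch of the loading side, assuming a hypothetical helper name and leaving the tensor split as a comment:

```cpp
#include "llama.h"

// Minimal sketch: the GPU options are applied when the model is loaded.
static llama_model * load_with_gpu_options(const char * path, int32_t n_gpu_layers, int32_t main_gpu) {
    llama_model_params mparams = llama_model_default_params();

    mparams.n_gpu_layers = n_gpu_layers;   // -ngl N
    mparams.main_gpu     = main_gpu;       // -mg i
    // mparams.tensor_split = ...;         // -ts SPLIT, an array of per-GPU proportions

    return llama_load_model_from_file(path, mparams);
}
```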
return 0;
}
- if (params.rope_freq_base != 10000.0) {
- LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+ if (params.n_ctx != 0 && params.n_ctx < 8) {
+ LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
+ params.n_ctx = 8;
+ }
+
+ if (params.rope_freq_base != 0.0) {
+ LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base);
}
- if (params.rope_freq_scale != 1.0) {
- LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+ if (params.rope_freq_scale != 0.0) {
+ LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
}
LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
return 1;
}
- const int n_ctx_train = llama_n_ctx_train(ctx);
- if (params.n_ctx > n_ctx_train) {
+ const int n_ctx_train = llama_n_ctx_train(model);
+ const int n_ctx = llama_n_ctx(ctx);
+ LOG("n_ctx: %d\n", n_ctx);
+
+ if (n_ctx > n_ctx_train) {
LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
- __func__, n_ctx_train, params.n_ctx);
- } else if (params.n_ctx < 8) {
- LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
- params.n_ctx = 8;
+ __func__, n_ctx_train, n_ctx);
}
// print system information
{
LOG_TEE("\n");
- LOG_TEE("system_info: n_threads = %d / %d | %s\n",
- params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+ LOG_TEE("%s\n", get_system_info(params).c_str());
}
std::string path_session = params.path_prompt_cache;
if (fp != NULL) {
std::fclose(fp);
- session_tokens.resize(params.n_ctx);
+ session_tokens.resize(n_ctx);
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
}
}
- const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+ const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
LOG("add_bos: %d\n", add_bos);
std::vector<llama_token> embd_inp;
LOG("guidance_offset: %s", log_tostr(guidance_offset));
}
- const int n_ctx = llama_n_ctx(ctx);
- LOG("n_ctx: %d\n", n_ctx);
-
if ((int) embd_inp.size() > n_ctx - 4) {
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
return 1;
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
- const int n_vocab = llama_n_vocab(ctx);
+ const int n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch);
- if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) {
+ if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
- if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
fflush(stderr);
const int n_ctx = llama_n_ctx(ctx);
- const int n_vocab = llama_n_vocab(ctx);
+ const int n_vocab = llama_n_vocab(model);
std::vector<client> clients(n_clients);
for (size_t i = 0; i < clients.size(); ++i) {
batch.logits[i] = false;
}
- if (llama_decode(ctx, batch, params.n_threads) != 0) {
+ if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
0, 0, 0, // unused
};
- const int ret = llama_decode(ctx, batch_view, params.n_threads);
+ const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
- const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+ const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
- if (int(tokens.size()) < 2*params.n_ctx) {
- fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
- params.n_ctx);
+ const int n_ctx = llama_n_ctx(ctx);
+
+ if (int(tokens.size()) < 2*n_ctx) {
+ fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
+ n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
return {tokens, -1, logit_history, prob_history};
}
- const int calc_chunk = params.n_ctx;
+ const int calc_chunk = n_ctx;
fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
if (int(tokens.size()) <= calc_chunk) {
fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
- tokens.size(), params.n_ctx, params.ppl_stride);
+ tokens.size(), n_ctx, params.ppl_stride);
return {tokens, -1, logit_history, prob_history};
}
const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
- const int n_vocab = llama_n_vocab(ctx);
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;
int count = 0;
const int batch_size = std::min(end - batch_start, n_batch);
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
- if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
//fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
}
//fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
- for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {
+ for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
// Calculate probability of next token, given the previous ones.
const std::vector<float> tok_logits(
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
- const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+ const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;
+ const int n_ctx = llama_n_ctx(ctx);
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
- if (int(tokens.size()) < 2*params.n_ctx) {
- fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
- params.n_ctx);
+ if (int(tokens.size()) < 2*n_ctx) {
+ fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
+ n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
std::vector<float> prob_history;
prob_history.resize(tokens.size());
- const int n_chunk_max = tokens.size() / params.n_ctx;
+ const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
- const int n_vocab = llama_n_vocab(ctx);
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;
int count = 0;
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
for (int i = 0; i < n_chunk; ++i) {
- const int start = i * params.n_ctx;
- const int end = start + params.n_ctx;
+ const int start = i * n_ctx;
+ const int end = start + n_ctx;
- const int num_batches = (params.n_ctx + n_batch - 1) / n_batch;
+ const int num_batches = (n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
tokens[batch_start] = llama_token_bos(ctx);
}
- if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;
- const auto batch_logits = llama_get_logits(ctx);
+ const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
// Example, we have a context window of 512, we will compute perplexity for each of the
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.
- const int first = params.n_ctx/2;
- process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first,
+ const int first = n_ctx/2;
+ process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
- count += params.n_ctx - first - 1;
+ count += n_ctx - first - 1;
// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
double av = nll/count;
double av2 = nll2/count - av*av;
if (av2 > 0) av2 = sqrt(av2/(count-1));
- printf("%8d %.4lf %4lf %4lf\n", i*params.n_ctx, std::exp(nll / count), av, av2);
+ printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
}
fflush(stdout);
}
}
static std::vector<float> hellaswag_evaluate_tokens(
- llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab, int n_thread
+ llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab
) {
std::vector<float> result;
result.reserve(tokens.size() * n_vocab);
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
size_t n_tokens = tokens.size() - i_chunk * n_batch;
n_tokens = std::min(n_tokens, size_t(n_batch));
- if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) {
+ if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {};
}
size_t hs_task_count = prompt_lines.size()/6;
fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
- const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+ const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
fprintf(stderr, "================================= is_spm = %d\n", is_spm);
// This is needed as usual for LLaMA models
printf("\ntask\tacc_norm\n");
double acc = 0.0f;
- const int n_vocab = llama_n_vocab(ctx);
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+ const int n_ctx = llama_n_ctx(ctx);
std::vector<std::vector<int>> ending_tokens(4);
auto query_size = query_embd.size();
// Stop if query wont fit the ctx window
- if (query_size > (size_t)params.n_ctx) {
+ if (query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
// clear the KV cache
llama_kv_cache_tokens_rm(ctx, -1, -1);
- auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
+ auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
query_size = query_embd.size();
// Stop if query wont fit the ctx window
- if (context_size + query_size > (size_t)params.n_ctx) {
+ if (context_size + query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
//}
// Evaluate the query
- logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
+ logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
return 1;
}
- const int n_ctx_train = llama_n_ctx_train(ctx);
+ const int n_ctx_train = llama_n_ctx_train(model);
if (params.n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);
// print system information
{
fprintf(stderr, "\n");
- fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
- params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+ fprintf(stderr, "%s\n", get_system_info(params).c_str());
}
struct results_perplexity results;
llama_context * ctx;
{
- auto lparams = llama_context_default_params();
+ auto mparams = llama_model_default_params();
+ mparams.use_mlock = false;
- lparams.n_ctx = 256;
- lparams.seed = 1;
- lparams.f16_kv = false;
- lparams.use_mlock = false;
-
- model = llama_load_model_from_file(params.model.c_str(), lparams);
+ model = llama_load_model_from_file(params.model.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return 1;
}
- ctx = llama_new_context_with_model(model, lparams);
+ auto cparams = llama_context_default_params();
+ cparams.n_ctx = 256;
+ cparams.seed = 1;
+ cparams.f16_kv = false;
+
+ ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
params.n_predict = 16;
}
- auto lparams = llama_context_default_params();
-
- lparams.n_ctx = params.n_ctx;
- lparams.seed = params.seed;
- lparams.f16_kv = params.memory_f16;
- lparams.use_mmap = params.use_mmap;
- lparams.use_mlock = params.use_mlock;
-
auto n_past = 0;
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
// init
- auto * model = llama_load_model_from_file(params.model.c_str(), lparams);
+ llama_model * model;
+ llama_context * ctx;
+
+ std::tie(model, ctx) = llama_init_from_gpt_params( params );
if (model == nullptr) {
return 1;
}
- auto * ctx = llama_new_context_with_model(model, lparams);
if (ctx == nullptr) {
llama_free_model(model);
return 1;
}
// evaluate prompt
- llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads);
+ llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0));
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
n_past += n_prompt_tokens;
for (auto i = 0; i < params.n_predict; i++) {
auto * logits = llama_get_logits(ctx);
- auto n_vocab = llama_n_vocab(ctx);
+ auto n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str.c_str());
- if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx);
llama_free_model(model);
llama_free(ctx);
// make new context
- auto * ctx2 = llama_new_context_with_model(model, lparams);
+ auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
// Load state (rng, logits, embedding and kv_cache) from file
{
// second run
for (auto i = 0; i < params.n_predict; i++) {
auto * logits = llama_get_logits(ctx2);
- auto n_vocab = llama_n_vocab(ctx2);
+ auto n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str.c_str());
- if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx2);
llama_free_model(model);
Command line options:
-- `--threads N`, `-t N`: Set the number of threads to use during computation.
+- `--threads N`, `-t N`: Set the number of threads to use during generation.
+- `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. If not specified, the number of threads will be set to the number of threads used for generation.
- `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`).
- `-m ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses.
- `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models; for example, baichuan models were built with a context of 4096.
- `-ngl N`, `--n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
-- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS.
- `-b N`, `--batch-size N`: Set the batch size for prompt processing. Default: `512`.
- `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. Not recommended.
- `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped.
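Several of the values above (the context size, vocabulary and embedding sizes) are now read back from the loaded model and context instead of from `gpt_params`. A minimal sketch, assuming a context created as in this patch; the helper name is illustrative:

```cpp
#include <cstdio>
#include "llama.h"

// Minimal sketch: query sizes from the model/context after initialization.
static void print_model_sizes(llama_context * ctx) {
    const llama_model * mdl = llama_get_model(ctx);

    const int n_ctx       = llama_n_ctx(ctx);        // effective context size (-c)
    const int n_ctx_train = llama_n_ctx_train(mdl);  // context size the model was trained with
    const int n_vocab     = llama_n_vocab(mdl);
    const int n_embd      = llama_n_embd(mdl);

    if (n_ctx > n_ctx_train) {
        fprintf(stderr, "warning: model was trained on only %d context tokens (%d specified)\n",
                n_ctx_train, n_ctx);
    }
    fprintf(stderr, "n_ctx = %d, n_vocab = %d, n_embd = %d\n", n_ctx, n_vocab, n_embd);
}
```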
llama_model *model = nullptr;
llama_context *ctx = nullptr;
gpt_params params;
+ int n_ctx;
grammar_parser::parse_state parsed_grammar;
llama_grammar *grammar = nullptr;
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
- generated_text.reserve(params.n_ctx);
+ generated_text.reserve(n_ctx);
generated_token_probs.clear();
truncated = false;
stopped_eos = false;
LOG_ERROR("unable to load model", {{"model", params_.model}});
return false;
}
-
- last_n_tokens.resize(params.n_ctx);
+ n_ctx = llama_n_ctx(ctx);
+ last_n_tokens.resize(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
return true;
}
{
params.n_keep = (int)num_prompt_tokens;
}
- params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+ params.n_keep = std::min(n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal
- if (num_prompt_tokens >= (size_t)params.n_ctx)
+ if (num_prompt_tokens >= (size_t)n_ctx)
{
- const int n_left = (params.n_ctx - params.n_keep) / 2;
+ const int n_left = (n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
- std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+ std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
LOG_VERBOSE("input truncated", {
- {"n_ctx", params.n_ctx},
+ {"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
completion_token_output result;
result.tok = -1;
- if (embd.size() >= (size_t)params.n_ctx)
+ if (embd.size() >= (size_t)n_ctx)
{
// Shift context
truncated = true;
LOG_VERBOSE("input truncated", {
- {"n_ctx", params.n_ctx},
+ {"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
});
n_eval = params.n_batch;
}
- if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
+ if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0)))
{
LOG_ERROR("failed to eval", {
{"n_eval", n_eval},
{"n_past", n_past},
- {"n_threads", params.n_threads},
{"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
});
has_next_token = false;
// out of user input, sample next token
const float temp = params.temp;
- const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+ const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(model) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
- const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
+ const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
{
auto *logits = llama_get_logits(ctx);
- auto n_vocab = llama_n_vocab(ctx);
+ auto n_vocab = llama_n_vocab(model);
// Apply params.logit_bias map
for (const auto &it : params.logit_bias)
// Apply penalties
float nl_logit = logits[llama_token_nl(ctx)];
- auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
+ auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
std::vector<float> getEmbedding()
{
- static const int n_embd = llama_n_embd(ctx);
+ static const int n_embd = llama_n_embd(model);
if (!params.embedding)
{
LOG_WARNING("embedding disabled", {
printf(" -ts SPLIT --tensor-split SPLIT\n");
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
- printf(" -lv, --low-vram don't allocate VRAM scratch buffer\n");
printf(" -nommq, --no-mul-mat-q\n");
printf(" use cuBLAS instead of custom mul_mat_q CUDA kernels.\n");
printf(" Not recommended since this is both slower and uses more VRAM.\n");
}
#else
LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {});
-#endif // GGML_USE_CUBLAS
- }
- else if (arg == "--low-vram" || arg == "-lv")
- {
-#ifdef GGML_USE_CUBLAS
- params.low_vram = true;
-#else
- LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n", {});
#endif // GGML_USE_CUBLAS
}
else if (arg == "--no-mul-mat-q" || arg == "-nommq")
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json{
- {"n_ctx", llama.params.n_ctx},
+ {"n_ctx", llama.n_ctx},
{"model", llama.params.model_alias},
{"seed", llama.params.seed},
{"temp", llama.params.temp},
const auto &logit_bias = body.find("logit_bias");
if (logit_bias != body.end() && logit_bias->is_array())
{
- const int n_vocab = llama_n_vocab(llama.ctx);
+ const int n_vocab = llama_n_vocab(llama.model);
for (const auto &el : *logit_bias)
{
if (el.is_array() && el.size() == 2 && el[0].is_number_integer())
{"commit", BUILD_COMMIT}});
LOG_INFO("system info", {
{"n_threads", params.n_threads},
+ {"n_threads_batch", params.n_threads_batch},
{"total_threads", std::thread::hardware_concurrency()},
{"system_info", llama_print_system_info()},
});
if (llama.params.n_beams) {
// Fill llama.generated_token_probs vector with final beam.
llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
- llama.n_past, llama.n_remain, llama.params.n_threads);
+ llama.n_past, llama.n_remain);
// Translate llama.generated_token_probs to llama.generated_text.
append_to_generated_text_from_generated_token_probs(llama);
} else {
llama_backend_init(params.numa);
- llama_context_params ctx_params = llama_context_default_params();
- ctx_params.seed = 1234;
- ctx_params.n_ctx = 2048;
- llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
+ // initialize the model
+ llama_model_params model_params = llama_model_default_params();
+ // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
+ // initialize the context
+
+ llama_context_params ctx_params = llama_context_default_params();
+
+ ctx_params.seed = 1234;
+ ctx_params.n_ctx = 2048;
+ ctx_params.n_threads = params.n_threads;
+ ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
- if (llama_decode(ctx, batch, params.n_threads) != 0) {
+ if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
while (n_cur <= n_len) {
// sample the next token
{
- auto n_vocab = llama_n_vocab(ctx);
+ auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
std::vector<llama_token_data> candidates;
n_cur += 1;
// evaluate the current batch with the transformer model
- if (llama_decode(ctx, batch, params.n_threads)) {
+ if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
const auto t_enc_start = ggml_time_us();
// eval the prompt with both models
- llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0), params.n_threads);
- llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0), params.n_threads);
- llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0), params.n_threads);
+ llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
+ llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
+ llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0));
const auto t_enc_end = ggml_time_us();
// the 2 models should have the same vocab
const int n_ctx = llama_n_ctx(ctx_tgt);
- const int n_vocab = llama_n_vocab(ctx_tgt);
- //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
+ const int n_vocab = llama_n_vocab(model_tgt);
+ //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
// how many tokens to draft each time
int n_draft = params.n_draft;
}
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
- llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
+ llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
++n_past_dft;
// heuristic for n_draft
// evaluate the drafted token on the draft model
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
- llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
+ llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
++n_past_cur;
if (grammar_dft != NULL) {
// evaluate the target model on the drafted tokens
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
- llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
+ llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
++n_past_tgt;
// the first token is always proposed by the target model before the speculation loop
printf("%s: seed: %u\n", __func__, params.common.seed);
srand(params.common.seed);
- struct llama_context_params llama_params = llama_context_default_params();
- llama_params.vocab_only = true;
+ struct llama_model_params mparams = llama_model_default_params();
+ mparams.vocab_only = true;
- struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params);
- struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
+ struct llama_context_params cparams = llama_context_default_params();
+
+ struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, mparams);
+ struct llama_context * lctx = llama_new_context_with_model(lmodel, cparams);
struct my_llama_model model;
- model.hparams.n_vocab = llama_n_vocab(lctx);
+ model.hparams.n_vocab = llama_n_vocab(lmodel);
model.hparams.n_ctx = params.common.n_ctx;
model.hparams.n_embd = params.n_embd;
model.hparams.n_head = params.n_head;
+#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
static bool g_mul_mat_q = true;
static void * g_scratch_buffer = nullptr;
-static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
+static size_t g_scratch_size = 0; // disabled by default
static size_t g_scratch_offset = 0;
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
const int64_t ne1 = dst->ne[1];
// TODO: find the optimal values for these
- if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
- src1->type == GGML_TYPE_F32 &&
- dst->type == GGML_TYPE_F32 &&
- (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
- return true;
- }
-
- return false;
+ return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+ src1->type == GGML_TYPE_F32 &&
+ dst->type == GGML_TYPE_F32 &&
+ (ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
}
static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream);
} else {
+ fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+ ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ASSERT(false);
}
}
void ggml_cuda_set_scratch_size(const size_t scratch_size) {
- g_scratch_size = scratch_size;
+ // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously
+ // it still won't always work as expected, but it's better than nothing
+ if (scratch_size > g_scratch_size) {
+ ggml_cuda_free_scratch();
+ }
+ g_scratch_size = std::max(g_scratch_size, scratch_size);
}
void ggml_cuda_free_scratch() {
static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
- const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
+ const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
if (n_tokens < 0) {
result.resize(-n_tokens);
- int check = llama_token_to_piece(ctx, token, result.data(), result.size());
+ int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
static const size_t GB = kB*kB*kB;
struct llama_hparams {
+ bool vocab_only;
uint32_t n_vocab;
uint32_t n_ctx_train; // context size the model was trained on
- uint32_t n_ctx; // context size used during inference
uint32_t n_embd;
uint32_t n_head;
uint32_t n_head_kv;
float f_norm_eps;
float f_norm_rms_eps;
- float rope_freq_base;
- float rope_freq_scale;
+ float rope_freq_base_train;
+ float rope_freq_scale_train;
bool operator!=(const llama_hparams & other) const {
return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
uint32_t n_embd_gqa() const {
return n_embd/n_gqa();
}
+};
- size_t kv_size() const {
- size_t result = 2ull;
- result *= (size_t) n_embd_gqa();
- result *= (size_t) n_ctx;
- result *= (size_t) n_layer;
- result *= sizeof(ggml_fp16_t);
- return result;
- }
+struct llama_cparams {
+ uint32_t n_ctx; // context size used during inference
+ uint32_t n_batch;
+ uint32_t n_threads; // number of threads to use for generation
+ uint32_t n_threads_batch; // number of threads to use for batch processing
+
+ float rope_freq_base;
+ float rope_freq_scale;
+
+ bool mul_mat_q;
};
struct llama_layer {
};
struct llama_context {
- llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
+ llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
~llama_context() {
- if (model_owner) {
- delete &model;
- }
#ifdef GGML_USE_METAL
if (ctx_metal) {
ggml_metal_free(ctx_metal);
}
}
+ llama_cparams cparams;
+
+ const llama_model & model;
+
+ // key + value cache for the self attention
+ struct llama_kv_cache kv_self;
+
std::mt19937 rng;
bool has_evaluated_once = false;
+ int64_t t_start_us;
+ int64_t t_load_us;
int64_t t_sample_us = 0;
- int64_t t_eval_us = 0;
int64_t t_p_eval_us = 0;
+ int64_t t_eval_us = 0;
int32_t n_sample = 0; // number of tokens sampled
- int32_t n_eval = 0; // number of eval calls
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
-
- const llama_model & model;
-
- bool model_owner = false;
-
- int64_t t_load_us;
- int64_t t_start_us;
-
- // key + value cache for the self attention
- struct llama_kv_cache kv_self;
+ int32_t n_eval = 0; // number of eval calls
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
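    // (row i holds the n_vocab logits for the i-th token of the batch; accessed via llama_get_logits_ith)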
const struct llama_hparams & hparams,
struct llama_kv_cache & cache,
ggml_type wtype,
+ uint32_t n_ctx,
int n_gpu_layers) {
const uint32_t n_embd = hparams.n_embd_gqa();
const uint32_t n_layer = hparams.n_layer;
- const uint32_t n_ctx = hparams.n_ctx;
const int64_t n_mem = n_layer*n_ctx;
const int64_t n_elements = n_embd*n_mem;
(void) n_gpu_layers;
#ifdef GGML_USE_CUBLAS
+ size_t vram_kv_cache = 0;
+
if (n_gpu_layers > (int)n_layer + 1) {
ggml_cuda_assign_buffers_no_scratch(cache.v);
+ LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
+ vram_kv_cache += ggml_nbytes(cache.v);
}
if (n_gpu_layers > (int)n_layer + 2) {
ggml_cuda_assign_buffers_no_scratch(cache.k);
+ LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
+ vram_kv_cache += ggml_nbytes(cache.k);
+ }
+ if (vram_kv_cache > 0) {
+ LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
}
#endif // GGML_USE_CUBLAS
lmlock->grow_to(size_lock);
}
break;
-#if defined(GGML_USE_CUBLAS)
+#ifdef GGML_USE_CUBLAS
case GGML_BACKEND_GPU:
case GGML_BACKEND_GPU_SPLIT:
// old code:
// load LLaMA models
//
-static std::string llama_model_ftype_name(enum llama_ftype ftype) {
+static std::string llama_model_arch_name(llm_arch arch) {
+ auto it = LLM_ARCH_NAMES.find(arch);
+ if (it == LLM_ARCH_NAMES.end()) {
+ return "unknown";
+ }
+ return it->second;
+}
+
+static std::string llama_model_ftype_name(llama_ftype ftype) {
if (ftype & LLAMA_FTYPE_GUESSED) {
return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
}
static void llm_load_hparams(
llama_model_loader & ml,
- llama_model & model,
- int n_ctx,
- float rope_freq_base,
- float rope_freq_scale) {
+ llama_model & model) {
struct gguf_context * ctx = ml.ctx_gguf;
const auto kv = LLM_KV(model.arch);
GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
// get hparams kv
- GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST));
- GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH));
- GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
- GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
- GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
- GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
+ GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST));
+ GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH));
+ GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
+ GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
+ GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
+ GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
// n_head_kv is optional, default to n_head
hparams.n_head_kv = hparams.n_head;
GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
// rope_freq_base (optional)
- if (rope_freq_base == 0.0f) {
- rope_freq_base = 10000.0f;
- GGUF_GET_KEY(ctx, rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
- }
+ hparams.rope_freq_base_train = 10000.0f;
+ GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
// rope_freq_scale (inverse of the kv) is optional
- if (rope_freq_scale == 0.0f) {
- float ropescale = 1.0f;
- GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
- rope_freq_scale = 1.0f/ropescale;
- }
+ float ropescale = 1.0f;
+ GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+ hparams.rope_freq_scale_train = 1.0f/ropescale;
// sanity check for n_rot (optional)
{
};
model.ftype = ml.ftype;
-
- hparams.n_ctx = n_ctx;
- hparams.rope_freq_base = rope_freq_base;
- hparams.rope_freq_scale = rope_freq_scale;
}
// TODO: This should probably be in llama.h
const auto & vocab = model.vocab;
// hparams
- LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
- LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
- LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
- LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
- LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size());
- LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
- LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx);
- LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
- LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head);
- LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
- LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
- LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
- LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
- LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
- LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
- LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
- LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
- LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
- LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
- LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
- LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
+ LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
+ LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
+ LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
+ LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
+ LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size());
+ LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
+ LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
+ LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head);
+ LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
+ LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
+ LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+ LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
+ LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
+ LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
+ LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
+ LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
+ LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
+ LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
+ LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
+ LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
if (ml.n_bytes < GB) {
- LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
+ LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
} else {
- LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
+ LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
}
// general kv
static void llm_load_tensors(
llama_model_loader & ml,
llama_model & model,
- int n_batch,
int n_gpu_layers,
int main_gpu,
const float * tensor_split,
- const bool mul_mat_q,
- bool low_vram,
- ggml_type memory_type,
bool use_mlock,
llama_progress_callback progress_callback,
void * progress_callback_user_data) {
}
(void) main_gpu;
- (void) mul_mat_q;
-#if defined(GGML_USE_CUBLAS)
+#ifdef GGML_USE_CUBLAS
LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
ggml_cuda_set_main_device(main_gpu);
- ggml_cuda_set_mul_mat_q(mul_mat_q);
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
#elif defined(GGML_USE_CLBLAST)
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
- backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+ backend_norm = LLAMA_BACKEND_OFFLOAD;
#else
- backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+ backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
#endif // _WIN32
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
- backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+ backend_norm = LLAMA_BACKEND_OFFLOAD;
#else
- backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+ backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
#endif // _WIN32
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
- backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+ backend_norm = LLAMA_BACKEND_OFFLOAD;
#else
- backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+ backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
#endif // _WIN32
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
- backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+ backend_norm = LLAMA_BACKEND_OFFLOAD;
#else
- backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+ backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
#endif // _WIN32
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
// print memory requirements
{
- const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
-
// this is the total memory required to run the inference
size_t mem_required =
ctx_size +
mmapped_size - vram_weights; // weights in VRAM not in memory
- // this is the memory required by one llama_state
- const size_t mem_required_state = scale*hparams.kv_size();
-
- LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
- mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
-
- (void) n_batch;
+ LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
if (n_gpu_layers > (int) hparams.n_layer) {
LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__);
}
- size_t vram_kv_cache = 0;
#ifdef GGML_USE_CUBLAS
const int max_backend_supported_layers = hparams.n_layer + 3;
- const int max_offloadable_layers = low_vram ? hparams.n_layer + 1 : hparams.n_layer + 3;
- if (n_gpu_layers > (int) hparams.n_layer + 1) {
- if (low_vram) {
- LLAMA_LOG_INFO("%s: cannot offload v cache to GPU due to low VRAM option\n", __func__);
- } else {
- LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
- vram_kv_cache += hparams.kv_size() / 2;
- }
- }
- if (n_gpu_layers > (int) hparams.n_layer + 2) {
- if (low_vram) {
- LLAMA_LOG_WARN("%s: cannot offload k cache to GPU due to low VRAM option\n", __func__);
- } else {
- LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
- vram_kv_cache += hparams.kv_size() / 2;
- }
- }
+ const int max_offloadable_layers = hparams.n_layer + 3;
#elif defined(GGML_USE_CLBLAST)
const int max_backend_supported_layers = hparams.n_layer + 1;
const int max_offloadable_layers = hparams.n_layer + 1;
#endif // GGML_USE_CUBLAS
- LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
- __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
- LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n",
- __func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up
+ LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
+ LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_weights / 1024.0 / 1024.0);
#else
(void) n_gpu_layers;
#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
}
(void) tensor_split;
-#if defined(GGML_USE_CUBLAS)
+#ifdef GGML_USE_CUBLAS
{
ggml_cuda_set_tensor_split(tensor_split);
}
static bool llama_model_load(
const std::string & fname,
llama_model & model,
- int n_ctx,
- int n_batch,
int n_gpu_layers,
int main_gpu,
const float * tensor_split,
- const bool mul_mat_q,
- float rope_freq_base,
- float rope_freq_scale,
- bool low_vram,
- ggml_type memory_type,
bool use_mmap,
bool use_mlock,
bool vocab_only,
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
try {
- std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname, use_mmap));
+ llama_model_loader ml(fname, use_mmap);
- llm_load_arch (*ml, model);
- llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale);
- llm_load_vocab (*ml, model);
+ model.hparams.vocab_only = vocab_only;
- llm_load_print_meta(*ml, model);
+ llm_load_arch (ml, model);
+ llm_load_hparams(ml, model);
+ llm_load_vocab (ml, model);
+
+ llm_load_print_meta(ml, model);
if (model.hparams.n_vocab != model.vocab.id_to_token.size()) {
throw std::runtime_error("vocab size mismatch");
}
llm_load_tensors(
- *ml, model, n_batch, n_gpu_layers,
- main_gpu, tensor_split, mul_mat_q, low_vram, memory_type,
+ ml, model, n_gpu_layers,
+ main_gpu, tensor_split,
use_mlock, progress_callback, progress_callback_user_data);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
const llama_batch & batch) {
const auto & model = lctx.model;
const auto & hparams = model.hparams;
+ const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = hparams.n_ctx;
+ const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
GGML_ASSERT(n_embd_head == hparams.n_rot);
- const float freq_base = hparams.rope_freq_base;
- const float freq_scale = hparams.rope_freq_scale;
+ const float freq_base = cparams.rope_freq_base;
+ const float freq_scale = cparams.rope_freq_scale;
const float norm_rms_eps = hparams.f_norm_rms_eps;
const int n_gpu_layers = model.n_gpu_layers;
// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
- //
- // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
- // in that case ggml_cuda_assign_buffers has no effect
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop;
const llama_batch & batch) {
const auto & model = lctx.model;
const auto & hparams = model.hparams;
+ const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = hparams.n_ctx;
+ const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
GGML_ASSERT(n_embd_head == hparams.n_rot);
- const float freq_base = hparams.rope_freq_base;
- const float freq_scale = hparams.rope_freq_scale;
+ const float freq_base = cparams.rope_freq_base;
+ const float freq_scale = cparams.rope_freq_scale;
const float norm_rms_eps = hparams.f_norm_rms_eps;
const int n_gpu_layers = model.n_gpu_layers;
// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
- //
- // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
- // in that case ggml_cuda_assign_buffers has no effect
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop;
const llama_batch & batch) {
const auto & model = lctx.model;
const auto & hparams = model.hparams;
+ const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = hparams.n_ctx;
+ const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
GGML_ASSERT(n_embd_head == hparams.n_rot);
- const float freq_base = hparams.rope_freq_base;
- const float freq_scale = hparams.rope_freq_scale;
+ const float freq_base = cparams.rope_freq_base;
+ const float freq_scale = cparams.rope_freq_scale;
const float norm_eps = hparams.f_norm_eps;
const int n_gpu_layers = model.n_gpu_layers;
// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
- //
- // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
- // in that case ggml_cuda_assign_buffers has no effect
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop;
const llama_batch & batch) {
const auto & model = lctx.model;
const auto & hparams = model.hparams;
+ const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = hparams.n_ctx;
+ const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
//
static int llama_decode_internal(
llama_context & lctx,
- llama_batch batch,
- int n_threads) {
+ llama_batch batch) {
const uint32_t n_tokens = batch.n_tokens;
if (n_tokens == 0) {
return -1;
}
+ const auto & model = lctx.model;
+ const auto & hparams = model.hparams;
+ const auto & cparams = lctx.cparams;
+
+ const auto n_batch = cparams.n_batch;
+
+ GGML_ASSERT(n_tokens <= n_batch);
+
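+ // single-token (generation) calls use n_threads; larger batches (prompt processing) use n_threads_batch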
+ int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
const int64_t t_start_us = ggml_time_us();
GGML_ASSERT(n_threads > 0);
- const auto & model = lctx.model;
- const auto & hparams = model.hparams;
-
auto & kv_self = lctx.kv_self;
GGML_ASSERT(!!kv_self.ctx);
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
//kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
- kv_self.n = std::min((int32_t) hparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
+ kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
//printf("kv_self.n = %d\n", kv_self.n);
ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
}
}
+
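+ // mul_mat_q is now a per-context setting, so it is (re)applied here on each decode call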
+ ggml_cuda_set_mul_mat_q(cparams.mul_mat_q);
#endif
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
GGML_ASSERT(ctx);
- auto n_vocab = llama_n_vocab(ctx);
+ auto n_vocab = llama_n_vocab(llama_get_model(ctx));
GGML_ASSERT(n_vocab == (int)candidates->size);
GGML_ASSERT(!candidates->sorted);
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
GGML_ASSERT(ctx);
- auto N = float(llama_n_vocab(ctx));
+ auto N = float(llama_n_vocab(llama_get_model(ctx)));
int64_t t_start_sample_us;
t_start_sample_us = ggml_time_us();
};
llama_logit_info(llama_context * ctx)
: logits(llama_get_logits(ctx))
- , n_vocab(llama_n_vocab(ctx))
+ , n_vocab(llama_n_vocab(llama_get_model(ctx)))
, max_l(*std::max_element(logits, logits + n_vocab))
, normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l}))
{ }
size_t n_beams;
int n_past;
int n_predict;
- int n_threads;
std::vector<llama_beam> beams;
std::vector<llama_beam> next_beams;
// Used to communicate to/from callback on beams state.
std::vector<llama_beam_view> beam_views;
- llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
+ llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict)
: ctx(ctx)
, n_beams(n_beams)
, n_past(n_past)
, n_predict(n_predict)
- , n_threads(n_threads)
, beam_views(n_beams) {
beams.reserve(n_beams);
next_beams.reserve(n_beams);
} else {
// beam is not at end-of-sentence, so branch with next top_k tokens.
if (!beam.tokens.empty()) {
- llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0), n_threads);
+ llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0));
}
llama_logit_info logit_info(ctx);
std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
callback(callback_data, get_beams_state(false)); // Sets common_prefix_length
update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed.
if (common_prefix_length) {
- llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0), n_threads);
+ llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0));
n_past += common_prefix_length;
}
// Zero-out next_beam probabilities to place them last in following min-heap.
void llama_beam_search(llama_context * ctx,
llama_beam_search_callback_fn_t callback, void * callback_data,
- size_t n_beams, int n_past, int n_predict, int n_threads) {
+ size_t n_beams, int n_past, int n_predict) {
assert(ctx);
const int64_t t_start_sample_us = ggml_time_us();
- llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads);
+ llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict);
beam_search_data.loop(callback, callback_data);
nthread = std::thread::hardware_concurrency();
}
- std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname_inp, /*use_mmap*/ false));
+ llama_model_loader ml(fname_inp, /*use_mmap*/ false);
llama_model model;
- llm_load_arch(*ml, model);
- llm_load_hparams(*ml, model, 0, 0, 0);
+ llm_load_arch(ml, model);
+ llm_load_hparams(ml, model);
if (params->only_copy) {
ftype = model.ftype;
struct gguf_context * ctx_out = gguf_init_empty();
// copy the KV pairs from the input file
- gguf_set_kv (ctx_out, ml->ctx_gguf);
+ gguf_set_kv (ctx_out, ml.ctx_gguf);
gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
gguf_set_val_u32(ctx_out, "general.file_type", ftype);
int n_attention_wv = 0;
int n_feed_forward_w2 = 0;
- for (int i = 0; i < ml->n_tensors; ++i) {
- struct ggml_tensor * meta = ml->get_tensor_meta(i);
+ for (int i = 0; i < ml.n_tensors; ++i) {
+ struct ggml_tensor * meta = ml.get_tensor_meta(i);
const std::string name = ggml_get_name(meta);
std::vector<no_init<float>> f32_conv_buf;
// populate the original tensors so we get an initial meta data
- for (int i = 0; i < ml->n_tensors; ++i) {
- struct ggml_tensor * meta = ml->get_tensor_meta(i);
+ for (int i = 0; i < ml.n_tensors; ++i) {
+ struct ggml_tensor * meta = ml.get_tensor_meta(i);
gguf_add_tensor(ctx_out, meta);
}
// placeholder for the meta data
::zeros(fout, meta_size);
- for (int i = 0; i < ml->n_tensors; ++i) {
- struct ggml_tensor * tensor = ml->get_tensor_meta(i);
+ for (int i = 0; i < ml.n_tensors; ++i) {
+ struct ggml_tensor * tensor = ml.get_tensor_meta(i);
const std::string name = ggml_get_name(tensor);
read_data.resize(ggml_nbytes(tensor));
}
tensor->data = read_data.data();
- ml->load_data_for(tensor);
+ ml.load_data_for(tensor);
LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
- ++idx, ml->n_tensors,
+ ++idx, ml.n_tensors,
ggml_get_name(tensor),
llama_format_tensor_shape(tensor).c_str(),
ggml_type_name(tensor->type));
}
}
-// TODO: after the GGUF PR, this likely won't work and needs to be updated
static int llama_apply_lora_from_file_internal(
const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
) {
//
// interface implementation
//
-
-struct llama_context_params llama_context_default_params() {
- struct llama_context_params result = {
- /*.seed =*/ LLAMA_DEFAULT_SEED,
- /*.n_ctx =*/ 512,
- /*.n_batch =*/ 512,
+struct llama_model_params llama_model_default_params() {
+ struct llama_model_params result = {
/*.n_gpu_layers =*/ 0,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
- /*.rope_freq_base =*/ 0.0f,
- /*.rope_freq_scale =*/ 0.0f,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
- /*.low_vram =*/ false,
- /*.mul_mat_q =*/ true,
- /*.f16_kv =*/ true,
- /*.logits_all =*/ false,
/*.vocab_only =*/ false,
/*.use_mmap =*/ true,
/*.use_mlock =*/ false,
- /*.embedding =*/ false,
};
#ifdef GGML_USE_METAL
return result;
}
+struct llama_context_params llama_context_default_params() {
+ struct llama_context_params result = {
+ /*.seed =*/ LLAMA_DEFAULT_SEED,
+ /*.n_ctx =*/ 512,
+ /*.n_batch =*/ 512,
+ /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
+ /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
+ /*.rope_freq_base =*/ 0.0f,
+ /*.rope_freq_scale =*/ 0.0f,
+ /*.mul_mat_q =*/ true,
+ /*.f16_kv =*/ true,
+ /*.logits_all =*/ false,
+ /*.embedding =*/ false,
+ };
+
+ return result;
+}
+
struct llama_model_quantize_params llama_model_quantize_default_params() {
struct llama_model_quantize_params result = {
/*.nthread =*/ 0,
struct llama_model * llama_load_model_from_file(
const char * path_model,
- struct llama_context_params params) {
+ struct llama_model_params params) {
ggml_time_init();
llama_model * model = new llama_model;
- ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
unsigned cur_percentage = 0;
if (params.progress_callback == NULL) {
params.progress_callback_user_data = &cur_percentage;
};
}
- if (!llama_model_load(path_model, *model, params.n_ctx, params.n_batch, params.n_gpu_layers,
- params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale,
- params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only,
+ if (!llama_model_load(path_model, *model, params.n_gpu_layers,
+ params.main_gpu, params.tensor_split,
+ params.use_mmap, params.use_mlock, params.vocab_only,
params.progress_callback, params.progress_callback_user_data)) {
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
delete model;
llama_context * ctx = new llama_context(*model);
+ const auto & hparams = model->hparams;
+ auto & cparams = ctx->cparams;
+
+ cparams.n_batch = params.n_batch;
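+ // 0 means "use the value the model was trained with" for n_ctx and the RoPE parameters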
+ cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
+ cparams.rope_freq_base = params.rope_freq_base == 0 ? hparams.rope_freq_base_train : params.rope_freq_base;
+ cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale;
+ cparams.n_threads = params.n_threads;
+ cparams.n_threads_batch = params.n_threads_batch;
+ cparams.mul_mat_q = params.mul_mat_q;
+
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
+ LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+ LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
+ LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
+
ctx->rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all;
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
// reserve memory for context buffers
- if (!params.vocab_only) {
- if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, params.n_gpu_layers)) {
+ if (!hparams.vocab_only) {
+ if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
}
- const auto & hparams = ctx->model.hparams;
-
// resized during inference
if (params.logits_all) {
- ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
+ ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab);
} else {
ctx->logits.reserve(hparams.n_vocab);
}
ctx->alloc = ggml_allocr_new_measure(tensor_alignment);
// build worst-case graph
- const uint32_t n_tokens = std::min((int) hparams.n_ctx, params.n_batch);
+ int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
+ int n_past = cparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
- ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, hparams.n_ctx - n_tokens, 0));
+ ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));
#ifdef GGML_USE_METAL
- if (params.n_gpu_layers > 0) {
+ if (model->n_gpu_layers > 0) {
ctx->ctx_metal = ggml_metal_init(1);
if (!ctx->ctx_metal) {
LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__);
// measure memory requirements for the graph
size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;
- LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
+ LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
// recreate allocator with exact memory requirements
ggml_allocr_free(ctx->alloc);
}
#endif
#ifdef GGML_USE_CUBLAS
- if (params.low_vram) {
- LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
- ggml_cuda_set_scratch_size(0); // disable scratch
- } else {
- ggml_cuda_set_scratch_size(alloc_size);
- LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0);
+ ggml_cuda_set_scratch_size(alloc_size);
+ LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0);
+
+ // calculate total VRAM usage
+ auto add_tensor = [](const ggml_tensor * t, size_t & size) {
+ if (t->backend == GGML_BACKEND_GPU || t->backend == GGML_BACKEND_GPU_SPLIT) {
+ size += ggml_nbytes(t);
+ }
+ };
+ size_t model_vram_size = 0;
+ for (const auto & kv : model->tensors_by_name) {
+ add_tensor(kv.second, model_vram_size);
}
+
+ size_t kv_vram_size = 0;
+ add_tensor(ctx->kv_self.k, kv_vram_size);
+ add_tensor(ctx->kv_self.v, kv_vram_size);
+
+ size_t ctx_vram_size = alloc_size + kv_vram_size;
+ size_t total_vram_size = model_vram_size + ctx_vram_size;
+
+ LLAMA_LOG_INFO("%s: total VRAM used: %.2f MB (model: %.2f MB, context: %.2f MB)\n", __func__,
+ total_vram_size / 1024.0 / 1024.0,
+ model_vram_size / 1024.0 / 1024.0,
+ ctx_vram_size / 1024.0 / 1024.0);
#endif
}
#ifdef GGML_USE_METAL
- if (params.n_gpu_layers > 0) {
+ if (model->n_gpu_layers > 0) {
// this allocates all Metal resources and memory buffers
void * data_ptr = NULL;
size_t data_size = 0;
- if (params.use_mmap) {
+ if (ctx->model.mapping) {
data_ptr = ctx->model.mapping->addr;
data_size = ctx->model.mapping->size;
} else {
return NULL; \
}
- LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size));
-
- LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.data, ctx->buf_compute.size, 0));
- LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0));
-
+ LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size));
+ LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.data, ctx->buf_alloc.size, 0));
#undef LLAMA_METAL_CHECK_BUF
}
return ctx;
}
-static struct llama_context * llama_init_from_file(
- const char * path_model,
- struct llama_context_params params) {
- struct llama_model * model = llama_load_model_from_file(path_model, params);
- if (!model) {
- return nullptr;
- }
-
- struct llama_context * ctx = llama_new_context_with_model(model, params);
- ctx->model_owner = true;
-
- return ctx;
-}
-
void llama_free(struct llama_context * ctx) {
delete ctx;
}
-int llama_n_vocab(const struct llama_context * ctx) {
- return llama_model_n_vocab(&ctx->model);
+const llama_model * llama_get_model(const struct llama_context * ctx) {
+ return &ctx->model;
}
int llama_n_ctx(const struct llama_context * ctx) {
- return llama_model_n_ctx(&ctx->model);
-}
-
-int llama_n_ctx_train(const struct llama_context * ctx) {
- return llama_model_n_ctx_train(&ctx->model);
+ return ctx->cparams.n_ctx;
}
-int llama_n_embd(const struct llama_context * ctx) {
- return llama_model_n_embd(&ctx->model);
+enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
+ return model->vocab.type;
}
-enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) {
- return ctx->model.vocab.type;
-}
-
-int llama_model_n_vocab(const struct llama_model * model) {
+int llama_n_vocab(const struct llama_model * model) {
return model->vocab.id_to_token.size();
}
-int llama_model_n_ctx(const struct llama_model * model) {
- return model->hparams.n_ctx;
-}
-
-int llama_model_n_ctx_train(const struct llama_model * model) {
+int llama_n_ctx_train(const struct llama_model * model) {
return model->hparams.n_ctx_train;
}
-int llama_model_n_embd(const struct llama_model * model) {
+int llama_n_embd(const struct llama_model * model) {
return model->hparams.n_embd;
}
int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
return snprintf(buf, buf_size, "%s %s %s",
- model->name.c_str(),
+ llama_model_arch_name(model->arch).c_str(),
llama_model_type_name(model->type),
llama_model_ftype_name(model->ftype).c_str());
}
{
const auto & kv_self = ctx->kv_self;
const auto & hparams = ctx->model.hparams;
+ const auto & cparams = ctx->cparams;
+
const int n_layer = hparams.n_layer;
const int n_embd = hparams.n_embd_gqa();
- const int n_ctx = hparams.n_ctx;
+ const int n_ctx = cparams.n_ctx;
const size_t kv_size = kv_self.buf.size;
const int kv_ntok = kv_self.head;
{
const auto & kv_self = ctx->kv_self;
const auto & hparams = ctx->model.hparams;
+ const auto & cparams = ctx->cparams;
+
const int n_layer = hparams.n_layer;
const int n_embd = hparams.n_embd_gqa();
- const int n_ctx = hparams.n_ctx;
+ const int n_ctx = cparams.n_ctx;
size_t kv_size;
int kv_ntok;
struct llama_context * ctx,
llama_token * tokens,
int32_t n_tokens,
- int n_past,
- int n_threads) {
+ int n_past) {
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
- const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads);
+ const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
struct llama_context * ctx,
float * embd,
int32_t n_tokens,
- int n_past,
- int n_threads) {
+ int n_past) {
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
- const int ret = llama_decode_internal(*ctx, batch, n_threads);
+ const int ret = llama_decode_internal(*ctx, batch);
if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
return ret;
}
+void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
+ ctx->cparams.n_threads = n_threads;
+ ctx->cparams.n_threads_batch = n_threads_batch;
+}
+
struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens,
int llama_decode(
struct llama_context * ctx,
- struct llama_batch batch,
- int n_threads) {
- const int ret = llama_decode_internal(*ctx, batch, n_threads);
+ struct llama_batch batch) {
+ const int ret = llama_decode_internal(*ctx, batch);
if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
}
int llama_tokenize(
- struct llama_context * ctx,
- const char * text,
- int text_len,
- llama_token * tokens,
- int n_max_tokens,
- bool add_bos) {
- return llama_tokenize_with_model(&ctx->model, text, text_len, tokens, n_max_tokens, add_bos);
-}
-
-int llama_tokenize_with_model(
const struct llama_model * model,
const char * text,
int text_len,
return res.size();
}
-int llama_token_to_piece(const struct llama_context * ctx, llama_token token, char * buf, int length) {
- return llama_token_to_piece_with_model(&ctx->model, token, buf, length);
-}
-
// does not write null-terminator to buf
-int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
- if (0 <= token && token < llama_model_n_vocab(model)) {
+int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) {
+ if (0 <= token && token < llama_n_vocab(model)) {
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) {
llama_seq_id all_seq_id; // used if seq_id == NULL
} llama_batch;
- struct llama_context_params {
- uint32_t seed; // RNG seed, -1 for random
- int32_t n_ctx; // text context
- int32_t n_batch; // prompt processing batch size
- int32_t n_gpu_layers; // number of layers to store in VRAM
- int32_t main_gpu; // the GPU that is used for scratch and small tensors
-
+ struct llama_model_params {
+ int32_t n_gpu_layers; // number of layers to store in VRAM
+ int32_t main_gpu; // the GPU that is used for scratch and small tensors
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
- // ref: https://github.com/ggerganov/llama.cpp/pull/2054
- float rope_freq_base; // RoPE base frequency
- float rope_freq_scale; // RoPE frequency scaling factor
-
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
// Keep the booleans together to avoid misalignment during copy-by-value.
- bool low_vram; // if true, reduce VRAM usage at the cost of performance
- bool mul_mat_q; // if true, use experimental mul_mat_q kernels
- bool f16_kv; // use fp16 for KV cache
- bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
+ };
+
+ struct llama_context_params {
+ uint32_t seed; // RNG seed, -1 for random
+ uint32_t n_ctx; // text context
+ uint32_t n_batch; // prompt processing batch size
+ uint32_t n_threads; // number of threads to use for generation
+ uint32_t n_threads_batch; // number of threads to use for batch processing
+
+ // ref: https://github.com/ggerganov/llama.cpp/pull/2054
+ float rope_freq_base; // RoPE base frequency
+ float rope_freq_scale; // RoPE frequency scaling factor
+
+ // Keep the booleans together to avoid misalignment during copy-by-value.
+ bool mul_mat_q; // if true, use experimental mul_mat_q kernels
+ bool f16_kv; // use fp16 for KV cache
+ bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool embedding; // embedding mode only
};
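A minimal usage sketch of the new split, assuming llama.h is included (the file path and numeric values are placeholders): model-level options go into llama_model_params at load time, while per-context options such as n_ctx and the thread counts go into llama_context_params.

    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 32;                  // model-level: layers to offload

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx           = 0;                // 0 = use the model's training context size
    cparams.n_threads       = 8;                // generation threads
    cparams.n_threads_batch = 16;               // prompt/batch processing threads

    llama_model   * model = llama_load_model_from_file("model.gguf", mparams);
    llama_context * ctx   = llama_new_context_with_model(model, cparams);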
};
// Helpers for getting default parameters
+ LLAMA_API struct llama_model_params llama_model_default_params(void);
LLAMA_API struct llama_context_params llama_context_default_params(void);
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
LLAMA_API struct llama_model * llama_load_model_from_file(
const char * path_model,
- struct llama_context_params params);
+ struct llama_model_params params);
LLAMA_API void llama_free_model(struct llama_model * model);
LLAMA_API bool llama_mmap_supported (void);
LLAMA_API bool llama_mlock_supported(void);
- LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
+ LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
+
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
- LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
- LLAMA_API int llama_n_embd (const struct llama_context * ctx);
- LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
+ LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
- LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
- LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
- LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
- LLAMA_API int llama_model_n_embd (const struct llama_model * model);
+ LLAMA_API int llama_n_vocab (const struct llama_model * model);
+ LLAMA_API int llama_n_ctx_train(const struct llama_model * model);
+ LLAMA_API int llama_n_embd (const struct llama_model * model);
// Get a string describing the model type
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
struct llama_context * ctx,
llama_token * tokens,
int32_t n_tokens,
- int n_past,
- int n_threads),
+ int n_past),
"use llama_decode() instead");
// Same as llama_eval, but use float matrix input directly.
struct llama_context * ctx,
float * embd,
int32_t n_tokens,
- int n_past,
- int n_threads),
+ int n_past),
"use llama_decode() instead");
// Return batch for single sequence of tokens starting at pos_0
// < 0 - error
LLAMA_API int llama_decode(
struct llama_context * ctx,
- struct llama_batch batch,
- int n_threads);
+ struct llama_batch batch);
+
+ // Set the number of threads used for decoding
+ // n_threads is the number of threads used for generation (single token)
+ // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
+ LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
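A short sketch of the intended call pattern (thread counts are arbitrary; ctx and batch are assumed to be set up as in the example changes above): the per-call n_threads argument is gone, so callers set the counts once on the context and then call llama_decode without them.

    llama_set_n_threads(ctx, 8, 16);   // 8 threads for generation, 16 for prompt/batch processing
    if (llama_decode(ctx, batch) != 0) {
        // handle decode failure
    }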
// Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
LLAMA_API int llama_tokenize(
- struct llama_context * ctx,
- const char * text,
- int text_len,
- llama_token * tokens,
- int n_max_tokens,
- bool add_bos);
-
- LLAMA_API int llama_tokenize_with_model(
const struct llama_model * model,
const char * text,
int text_len,
// Does not write null terminator to the buffer.
// User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
LLAMA_API int llama_token_to_piece(
- const struct llama_context * ctx,
- llama_token token,
- char * buf,
- int length);
-
- LLAMA_API int llama_token_to_piece_with_model(
const struct llama_model * model,
llama_token token,
char * buf,
/// @param n_beams Number of beams to use.
/// @param n_past Number of tokens already evaluated.
/// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
- /// @param n_threads Number of threads as passed to llama_eval().
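+ /// Note: the thread count is now taken from the context (see llama_set_n_threads).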
LLAMA_API void llama_beam_search(
struct llama_context * ctx,
llama_beam_search_callback_fn_t callback,
void * callback_data,
size_t n_beams,
int n_past,
- int n_predict,
- int n_threads);
+ int n_predict);
// Performance information
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
// load the vocab
{
- auto lparams = llama_context_default_params();
+ auto mparams = llama_model_default_params();
- lparams.vocab_only = true;
+ mparams.vocab_only = true;
- model = llama_load_model_from_file(fname.c_str(), lparams);
+ model = llama_load_model_from_file(fname.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}
- ctx = llama_new_context_with_model(model, lparams);
+ auto cparams = llama_context_default_params();
+
+ ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
}
}
- if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_BPE) {
+ if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
fprintf(stderr, "%s : error: vocab type is not SPM\n", __func__);
llama_free_model(model);
llama_free(ctx);
// load the vocab
{
- auto lparams = llama_context_default_params();
+ auto mparams = llama_model_default_params();
- lparams.vocab_only = true;
+ mparams.vocab_only = true;
- model = llama_load_model_from_file(fname.c_str(), lparams);
+ model = llama_load_model_from_file(fname.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}
- ctx = llama_new_context_with_model(model, lparams);
+ auto cparams = llama_context_default_params();
+
+ ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
}
}
- if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_SPM) {
+ if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_SPM) {
fprintf(stderr, "%s : error: vocab type is not SPM\n", __func__);
llama_free_model(model);
llama_free(ctx);
// load the vocab
{
- auto lparams = llama_context_default_params();
+ auto mparams = llama_model_default_params();
- lparams.vocab_only = true;
+ mparams.vocab_only = true;
- model = llama_load_model_from_file(fname.c_str(), lparams);
+ model = llama_load_model_from_file(fname.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}
- ctx = llama_new_context_with_model(model, lparams);
+ auto cparams = llama_context_default_params();
+
+ ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
}
}
- GGML_ASSERT(llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM);
+ GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
#ifdef _WIN32
// We need this for unicode console support
atexit([]() { console::cleanup(); });
#endif
- const int n_vocab = llama_n_vocab(ctx);
+ const int n_vocab = llama_n_vocab(model);
for (int i = 0; i < n_vocab; ++i) {
std::string str = llama_detokenize_spm(ctx, std::vector<int>(1, i));