]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama.cpp : split llama_context_params into model and context params (#3301)
authorslaren <redacted>
Thu, 28 Sep 2023 19:42:38 +0000 (21:42 +0200)
committerGitHub <redacted>
Thu, 28 Sep 2023 19:42:38 +0000 (22:42 +0300)
* llama.cpp : split llama_context_params into model and context params

ggml-ci

* fix metal build

* fix freq_base/scale default to model value

* llama-bench : keep the same model between tests when possible

* move n_threads to llama_context_params, add n_threads_batch

* fix mpi build

* remove kv_size(), cuda scratch fixes

* remove low-vram option

* add n_threads_batch to system info, refactor to get_system_info()

* add documentation about --threads-batch to the READMEs

* llama-bench fix

* main : fix rope freq/scale warning

* llama.cpp : add llama_get_model
common : add llama_tokenize from model

* remove duplicated ctx/model functions

ggml-ci

* cuda : print total VRAM used

27 files changed:
common/common.cpp
common/common.h
common/train.cpp
examples/batched/batched.cpp
examples/beam-search/beam-search.cpp
examples/embd-input/embd-input-lib.cpp
examples/embd-input/embd-input-test.cpp
examples/embedding/embedding.cpp
examples/finetune/finetune.cpp
examples/llama-bench/llama-bench.cpp
examples/main/README.md
examples/main/main.cpp
examples/parallel/parallel.cpp
examples/perplexity/perplexity.cpp
examples/quantize-stats/quantize-stats.cpp
examples/save-load-state/save-load-state.cpp
examples/server/README.md
examples/server/server.cpp
examples/simple/simple.cpp
examples/speculative/speculative.cpp
examples/train-text-from-scratch/train-text-from-scratch.cpp
ggml-cuda.cu
llama.cpp
llama.h
tests/test-tokenizer-0-falcon.cpp
tests/test-tokenizer-0-llama.cpp
tests/test-tokenizer-1-llama.cpp

index 8764a7be3c9ce177e9bf60e4027acdd90ebe0c1b..6e8c08cb883873358ce7241db087c1d353332e34 100644 (file)
@@ -129,6 +129,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             if (params.n_threads <= 0) {
                 params.n_threads = std::thread::hardware_concurrency();
             }
+        } else if (arg == "-tb" || arg == "--threads-batch") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_threads_batch = std::stoi(argv[i]);
+            if (params.n_threads_batch <= 0) {
+                params.n_threads_batch = std::thread::hardware_concurrency();
+            }
         } else if (arg == "-p" || arg == "--prompt") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -451,12 +460,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.mul_mat_q = false;
 #else
             fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
-#endif // GGML_USE_CUBLAS
-        } else if (arg == "--low-vram" || arg == "-lv") {
-#ifdef GGML_USE_CUBLAS
-            params.low_vram = true;
-#else
-            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
 #endif // GGML_USE_CUBLAS
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
@@ -630,7 +633,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        (can be specified more than once for multiple prompts).\n");
     printf("  --color               colorise output to distinguish prompt and user input from generations\n");
     printf("  -s SEED, --seed SEED  RNG seed (default: -1, use random seed for < 0)\n");
-    printf("  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    printf("  -t N, --threads N     number of threads to use during generation (default: %d)\n", params.n_threads);
+    printf("  -tb N, --threads-batch N\n");
+    printf("                        number of threads to use during batch and prompt processing (default: same as --threads)\n");
     printf("  -p PROMPT, --prompt PROMPT\n");
     printf("                        prompt to start generation with (default: empty)\n");
     printf("  -e, --escape          process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
@@ -645,7 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -f FNAME, --file FNAME\n");
     printf("                        prompt file to start generation.\n");
     printf("  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
-    printf("  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    printf("  -c N, --ctx-size N    size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     printf("  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
@@ -705,7 +710,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -ts SPLIT --tensor-split SPLIT\n");
     printf("                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
     printf("  -mg i, --main-gpu i   the GPU to use for scratch and small tensors\n");
-    printf("  -lv, --low-vram       don't allocate VRAM scratch buffer\n");
 #ifdef GGML_USE_CUBLAS
     printf("  -nommq, --no-mul-mat-q\n");
     printf("                        use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
@@ -726,6 +730,18 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("\n");
 }
 
+std::string get_system_info(const gpt_params & params) {
+    std::ostringstream os;
+
+    os << "system_info: n_threads = " << params.n_threads;
+    if (params.n_threads_batch != -1) {
+        os << " (n_threads_batch = " << params.n_threads_batch << ")";
+    }
+    os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
+
+    return os.str();
+}
+
 std::string gpt_random_prompt(std::mt19937 & rng) {
     const int r = rng() % 10;
     switch (r) {
@@ -749,40 +765,50 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
 // Model utils
 //
 
-struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
-    auto lparams = llama_context_default_params();
+struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
+    auto mparams = llama_model_default_params();
 
-    lparams.n_ctx           = params.n_ctx;
-    lparams.n_batch         = params.n_batch;
     if (params.n_gpu_layers != -1) {
-        lparams.n_gpu_layers = params.n_gpu_layers;
+        mparams.n_gpu_layers = params.n_gpu_layers;
     }
-    lparams.main_gpu        = params.main_gpu;
-    lparams.tensor_split    = params.tensor_split;
-    lparams.low_vram        = params.low_vram;
-    lparams.mul_mat_q       = params.mul_mat_q;
-    lparams.seed            = params.seed;
-    lparams.f16_kv          = params.memory_f16;
-    lparams.use_mmap        = params.use_mmap;
-    lparams.use_mlock       = params.use_mlock;
-    lparams.logits_all      = params.logits_all;
-    lparams.embedding       = params.embedding;
-    lparams.rope_freq_base  = params.rope_freq_base;
-    lparams.rope_freq_scale = params.rope_freq_scale;
-
-    return lparams;
+    mparams.main_gpu        = params.main_gpu;
+    mparams.tensor_split    = params.tensor_split;
+    mparams.use_mmap        = params.use_mmap;
+    mparams.use_mlock       = params.use_mlock;
+
+    return mparams;
+}
+
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
+    auto cparams = llama_context_default_params();
+
+    cparams.n_ctx           = params.n_ctx;
+    cparams.n_batch         = params.n_batch;
+    cparams.n_threads       = params.n_threads;
+    cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    cparams.mul_mat_q       = params.mul_mat_q;
+    cparams.seed            = params.seed;
+    cparams.f16_kv          = params.memory_f16;
+    cparams.logits_all      = params.logits_all;
+    cparams.embedding       = params.embedding;
+    cparams.rope_freq_base  = params.rope_freq_base;
+    cparams.rope_freq_scale = params.rope_freq_scale;
+
+    return cparams;
 }
 
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
-    auto lparams = llama_context_params_from_gpt_params(params);
+    auto mparams = llama_model_params_from_gpt_params(params);
 
-    llama_model * model  = llama_load_model_from_file(params.model.c_str(), lparams);
+    llama_model * model  = llama_load_model_from_file(params.model.c_str(), mparams);
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
         return std::make_tuple(nullptr, nullptr);
     }
 
-    llama_context * lctx = llama_new_context_with_model(model, lparams);
+    auto cparams = llama_context_params_from_gpt_params(params);
+
+    llama_context * lctx = llama_new_context_with_model(model, cparams);
     if (lctx == NULL) {
         fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
         llama_free_model(model);
@@ -815,7 +841,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         LOG("warming up the model with an empty run\n");
 
         std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
-        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
+        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_tokens_rm(lctx, -1, -1);
         llama_reset_timings(lctx);
     }
@@ -828,16 +854,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 //
 
 std::vector<llama_token> llama_tokenize(
-        struct llama_context * ctx,
+  const struct llama_context * ctx,
+           const std::string & text,
+                        bool   add_bos) {
+    return llama_tokenize(llama_get_model(ctx), text, add_bos);
+}
+
+std::vector<llama_token> llama_tokenize(
+    const struct llama_model * model,
            const std::string & text,
                         bool   add_bos) {
     // upper limit for the number of tokens
     int n_tokens = text.length() + add_bos;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -847,10 +880,10 @@ std::vector<llama_token> llama_tokenize(
 
 std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
+    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_piece(ctx, token, result.data(), result.size());
+        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -905,7 +938,7 @@ llama_token llama_sample_token(
          std::vector<llama_token_data> & candidates,
                                    int   idx) {
     const int n_ctx   = llama_n_ctx(ctx);
-    const int n_vocab = llama_n_vocab(ctx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
     const float   temp            = params.temp;
     const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
@@ -1191,7 +1224,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
 #endif // NDEBUG
 
     fprintf(stream, "model_desc: %s\n", model_desc);
-    fprintf(stream, "n_vocab: %d  # output size of the final layer, 32001 for some models\n", llama_n_vocab(lctx));
+    fprintf(stream, "n_vocab: %d  # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx)));
 
 #ifdef __OPTIMIZE__
     fprintf(stream, "optimize: true\n");
@@ -1258,7 +1291,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
         fprintf(stream, "  - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
     }
     fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
-    fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false");
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
     fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
     fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat);
index 64601f99742d9ab3d30ce187b291cbecfb417bc2..0e2d3fa6c07d9be88dab434e444c34611eff71a3 100644 (file)
@@ -36,6 +36,7 @@ int32_t get_num_physical_cores();
 struct gpt_params {
     uint32_t seed                           = -1;   // RNG seed
     int32_t n_threads                       = get_num_physical_cores();
+    int32_t n_threads_batch                 = -1;   // number of threads to use for batch processing (-1 = use n_threads)
     int32_t n_predict                       = -1;   // new tokens to predict
     int32_t n_ctx                           = 512;  // context size
     int32_t n_batch                         = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
@@ -95,7 +96,6 @@ struct gpt_params {
     bool hellaswag         = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
     size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score
 
-    bool low_vram          = false; // if true, reduce VRAM usage at the cost of performance
     bool mul_mat_q         = true;  // if true, use mul_mat_q kernels instead of cuBLAS
     bool memory_f16        = true;  // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
@@ -126,6 +126,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
 
 void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
 
+std::string get_system_info(const gpt_params & params);
+
 std::string gpt_random_prompt(std::mt19937 & rng);
 
 void process_escapes(std::string& input);
@@ -135,6 +137,7 @@ void process_escapes(std::string& input);
 //
 
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
+struct llama_model_params   llama_model_params_from_gpt_params(const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
 //
@@ -144,7 +147,12 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 // tokenizes a string into a vector of tokens
 // should work similar to Python's `tokenizer.encode`
 std::vector<llama_token> llama_tokenize(
-        struct llama_context * ctx,
+  const struct llama_context * ctx,
+           const std::string & text,
+                        bool   add_bos);
+
+std::vector<llama_token> llama_tokenize(
+    const struct llama_model * model,
            const std::string & text,
                         bool   add_bos);
 
index 4a12809663a1a86fb7cae30ea62194e9be0b8c01..35a4cf9e6cae39b16c7c991760b9ec98812d02a8 100644 (file)
@@ -858,7 +858,7 @@ size_t tokenize_file(
         out_tokens.resize(buf.size() + n_max_tokens_overhead);
 
         int n_tokens = llama_tokenize(
-            lctx,
+            llama_get_model(lctx),
             buf.data(),
             (int) buf.size(),
             out_tokens.data(),
@@ -867,7 +867,7 @@ size_t tokenize_file(
         if (n_tokens < 0) {
             out_tokens.resize(-n_tokens);
             n_tokens = llama_tokenize(
-                lctx,
+                llama_get_model(lctx),
                 buf.data(),
                 (int) buf.size(),
                 out_tokens.data(),
@@ -920,7 +920,7 @@ size_t tokenize_file(
         size_t found_max_sample_size  = 0;
 
         size_t max_token_text_size = 0;
-        int n_vocab = llama_n_vocab(lctx);
+        int n_vocab = llama_n_vocab(llama_get_model(lctx));
         for (llama_token token=0; token < n_vocab; ++token) {
             max_token_text_size = std::max(
                 max_token_text_size,
@@ -961,7 +961,7 @@ size_t tokenize_file(
 
                 // tokenize the sample
                 tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
-                int n_tokens = llama_tokenize(lctx,
+                int n_tokens = llama_tokenize(llama_get_model(lctx),
                     buf_sample.data(),
                     (int) buf_sample.size(),
                     tok_sample.data(),
@@ -969,7 +969,7 @@ size_t tokenize_file(
                     false);
                 if (n_tokens < 0) {
                     tok_sample.resize(-n_tokens);
-                    n_tokens = llama_tokenize(lctx,
+                    n_tokens = llama_tokenize(llama_get_model(lctx),
                         buf_sample.data(),
                         (int) buf_sample.size(),
                         tok_sample.data(),
index 4dd1d553d1c1838a33a9f9c873ac165bd0084433..688ef221335a98f63117855c5844d39a4f5591cf 100644 (file)
@@ -40,20 +40,35 @@ int main(int argc, char ** argv) {
 
     llama_backend_init(params.numa);
 
-    llama_context_params ctx_params = llama_context_default_params();
+    // initialize the model
 
-    ctx_params.seed  = 1234;
-    ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
-    ctx_params.n_batch = std::max(n_len, n_parallel);
-    // ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
+    llama_model_params model_params = llama_model_default_params();
 
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
+    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
 
     if (model == NULL) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
         return 1;
     }
 
+    // tokenize the prompt
+
+    std::vector<llama_token> tokens_list;
+    tokens_list = ::llama_tokenize(model, params.prompt, true);
+    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
+
+    // initialize the context
+
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.seed  = 1234;
+    ctx_params.n_ctx = n_kv_req;
+    ctx_params.n_batch = std::max(n_len, n_parallel);
+    ctx_params.n_threads = params.n_threads;
+    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
@@ -61,13 +76,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    // tokenize the prompt
-
-    std::vector<llama_token> tokens_list;
-    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
-
     const int n_ctx    = llama_n_ctx(ctx);
-    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
 
     LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
 
@@ -106,7 +115,7 @@ int main(int argc, char ** argv) {
     // llama_decode will output logits only for the last token of the prompt
     batch.logits[batch.n_tokens - 1] = true;
 
-    if (llama_decode(ctx, batch, params.n_threads) != 0) {
+    if (llama_decode(ctx, batch) != 0) {
         LOG_TEE("%s: llama_decode() failed\n", __func__);
         return 1;
     }
@@ -146,7 +155,7 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            auto   n_vocab = llama_n_vocab(ctx);
+            auto   n_vocab = llama_n_vocab(model);
             auto * logits  = llama_get_logits_ith(ctx, i_batch[i]);
 
             std::vector<llama_token_data> candidates;
@@ -210,7 +219,7 @@ int main(int argc, char ** argv) {
         n_cur += 1;
 
         // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch, params.n_threads)) {
+        if (llama_decode(ctx, batch)) {
             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
index 63da7c3ec02a523a59e386ba660119b42f7d33ff..f078ab8a87fa5fba0c0e21bfc9b257a039fe64c6 100644 (file)
@@ -160,7 +160,7 @@ int main(int argc, char ** argv)
 
     int n_past = 0;
 
-    if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads))
+    if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0)))
     {
         fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
         return 1;
@@ -170,7 +170,7 @@ int main(int argc, char ** argv)
     beam_search_callback_data callback_data{ctx, {}};
     size_t const beam_width = static_cast<size_t>(params.n_beams);
     int const n_predict = 256;
-    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);
+    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict);
 
     std::cout << "\n\n";
     for (llama_token const token_id : callback_data.response) {
index 9bd4d34705cbc112b319054a7c07f126f9ef06e7..99e6bdad5ac45442a05e2663bea6d3ce2b3aeae9 100644 (file)
@@ -48,8 +48,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) {
     // print system information
     {
         fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+        fprintf(stderr, "%s\n", get_system_info(params).c_str());
     }
     struct MyModel * ret = new MyModel();
     ret->ctx = ctx;
@@ -71,7 +70,7 @@ bool eval_float(void * model, float * input, int N){
     MyModel * mymodel = (MyModel*)model;
     llama_context * ctx = mymodel->ctx;
     gpt_params params = mymodel->params;
-    int n_emb = llama_n_embd(ctx);
+    int n_emb = llama_n_embd(llama_get_model(ctx));
     int n_past = mymodel->n_past;
     int n_batch = N; // params.n_batch;
 
@@ -81,7 +80,7 @@ bool eval_float(void * model, float * input, int N){
             n_eval = n_batch;
         }
         llama_batch batch = {  int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
-        if (llama_decode(ctx, batch, params.n_threads)) {
+        if (llama_decode(ctx, batch)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;
         }
@@ -102,7 +101,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
         if (n_eval > params.n_batch) {
             n_eval = params.n_batch;
         }
-        if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) {
+        if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0))) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;
         }
@@ -133,7 +132,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
 
     // out of user input, sample next token
     const float   temp            = params.temp;
-    const int32_t top_k           = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+    const int32_t top_k           = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params.top_k;
     const float   top_p           = params.top_p;
     const float   tfs_z           = params.tfs_z;
     const float   typical_p       = params.typical_p;
@@ -149,7 +148,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
     llama_token id = 0;
     {
         auto logits  = llama_get_logits(ctx);
-        auto n_vocab = llama_n_vocab(ctx);
+        auto n_vocab = llama_n_vocab(llama_get_model(ctx));
 
         // Apply params.logit_bias map
         for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
index e5e040f62a60a54bf16e1aeb8e4e687b8f933a9e..dc4a0e48854adce86d44a3ba8ce7d2a67996b125 100644 (file)
@@ -8,7 +8,7 @@ int main(int argc, char** argv) {
     auto mymodel = create_mymodel(argc, argv);
     int N = 10;
     int max_tgt_len = 500;
-    int n_embd = llama_n_embd(mymodel->ctx);
+    int n_embd = llama_n_embd(llama_get_model(mymodel->ctx));
 
     // add random float embd to test evaluation
     float * data = new float[N*n_embd];
index 18cefa237bbc1c4400e51d73c212e26e1ed6c490..14075609ebfd919b534bb2f41b1bd14b2d867673 100644 (file)
@@ -42,17 +42,18 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const int n_ctx_train = llama_n_ctx_train(ctx);
-    if (params.n_ctx > n_ctx_train) {
+    const int n_ctx_train = llama_n_ctx_train(model);
+    const int n_ctx = llama_n_ctx(ctx);
+
+    if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
-                __func__, n_ctx_train, params.n_ctx);
+                __func__, n_ctx_train, n_ctx);
     }
 
     // print system information
     {
         fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+        fprintf(stderr, "%s\n", get_system_info(params).c_str());
     }
 
     int n_past = 0;
@@ -70,15 +71,15 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "\n");
     }
 
-    if (embd_inp.size() > (size_t)params.n_ctx) {
+    if (embd_inp.size() > (size_t)n_ctx) {
         fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
-                __func__, embd_inp.size(), params.n_ctx);
+                __func__, embd_inp.size(), n_ctx);
         return 1;
     }
 
     while (!embd_inp.empty()) {
         int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
-        if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) {
+        if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return 1;
         }
@@ -86,8 +87,8 @@ int main(int argc, char ** argv) {
         embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
     }
 
-    const int n_embd = llama_n_embd(ctx);
-    const auto embeddings = llama_get_embeddings(ctx);
+    const int n_embd = llama_n_embd(model);
+    const auto embeddings = llama_get_embeddings(ctx);
 
     for (int i = 0; i < n_embd; i++) {
         printf("%f ", embeddings[i]);
index 6e29e1c15e3f727a2851de2aef540b36f442233f..b61165fb7c6c93bd0e08b338e3a25fd694faaa8b 100644 (file)
@@ -304,7 +304,7 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
 
         gguf_free(mctx);
     }
-    hparams.n_vocab = llama_model_n_vocab(input);
+    hparams.n_vocab = llama_n_vocab(input);
     hparams.n_ctx = n_ctx;
 
     // get tensors from llama_model (possibly mmapped)
@@ -1540,12 +1540,14 @@ int main(int argc, char ** argv) {
     printf("%s: seed: %u\n", __func__, params.common.seed);
     srand(params.common.seed);
 
-    struct llama_context_params llama_params = llama_context_default_params();
-    llama_params.vocab_only = false;
+    struct llama_model_params llama_mparams = llama_model_default_params();
+    llama_mparams.vocab_only = false;
 
     printf("%s: model base = '%s'\n", __func__, params.fn_model_base);
-    struct llama_model * lmodel = llama_load_model_from_file(params.fn_model_base, llama_params);
-    struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
+    struct llama_model * lmodel = llama_load_model_from_file(params.fn_model_base, llama_mparams);
+
+    struct llama_context_params llama_cparams = llama_context_default_params();
+    struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_cparams);
 
     struct my_llama_model model;
     init_model(lmodel, &model, params.fn_model_base, params.common.n_ctx);
index 058e34d5c275c277b929925c561e5a98000a3868..93bb0c8b1916b50990cbdeb5272df0b4fdd4588b 100644 (file)
@@ -132,7 +132,6 @@ struct cmd_params {
     std::vector<int> n_gpu_layers;
     std::vector<int> main_gpu;
     std::vector<bool> mul_mat_q;
-    std::vector<bool> low_vram;
     std::vector<std::array<float, LLAMA_MAX_DEVICES>> tensor_split;
     int reps;
     bool verbose;
@@ -149,7 +148,6 @@ static const cmd_params cmd_params_defaults = {
     /* n_gpu_layers  */ {99},
     /* main_gpu      */ {0},
     /* mul_mat_q     */ {true},
-    /* low_vram      */ {false},
     /* tensor_split  */ {{}},
     /* reps          */ 5,
     /* verbose       */ false,
@@ -167,9 +165,8 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -b, --batch-size <n>              (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
     printf("  --memory-f32 <0|1>                (default: %s)\n", join(cmd_params_defaults.f32_kv, ",").c_str());
     printf("  -t, --threads <n>                 (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
-    printf("  -ngl N, --n-gpu-layers <n>        (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
-    printf("  -mg i, --main-gpu <n>             (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
-    printf("  -lv, --low-vram <0|1>             (default: %s)\n", join(cmd_params_defaults.low_vram, ",").c_str());
+    printf("  -ngl, --n-gpu-layers <n>          (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
+    printf("  -mg, --main-gpu <i>               (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
     printf("  -mmq, --mul-mat-q <0|1>           (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str());
     printf("  -ts, --tensor_split <ts0/ts1/..>               \n");
     printf("  -r, --repetitions <n>             (default: %d)\n", cmd_params_defaults.reps);
@@ -255,13 +252,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 break;
             }
             params.main_gpu = split<int>(argv[i], split_delim);
-        } else if (arg == "-lv" || arg == "--low-vram") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = split<bool>(argv[i], split_delim);
-            params.low_vram.insert(params.low_vram.end(), p.begin(), p.end());
         } else if (arg == "-mmq" || arg == "--mul-mat-q") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -336,7 +326,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
     if (params.main_gpu.empty())     { params.main_gpu = cmd_params_defaults.main_gpu; }
     if (params.mul_mat_q.empty())    { params.mul_mat_q = cmd_params_defaults.mul_mat_q; }
-    if (params.low_vram.empty())     { params.low_vram = cmd_params_defaults.low_vram; }
     if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
     if (params.n_threads.empty())    { params.n_threads = cmd_params_defaults.n_threads; }
 
@@ -353,21 +342,34 @@ struct cmd_params_instance {
     int n_gpu_layers;
     int main_gpu;
     bool mul_mat_q;
-    bool low_vram;
     std::array<float, LLAMA_MAX_DEVICES> tensor_split;
 
-    llama_context_params to_llama_params() const {
-        llama_context_params lparams = llama_context_default_params();
-        lparams.n_ctx = n_prompt + n_gen;
-        lparams.n_batch = n_batch;
-        lparams.f16_kv = !f32_kv;
-        lparams.n_gpu_layers = n_gpu_layers;
-        lparams.main_gpu = main_gpu;
-        lparams.mul_mat_q = mul_mat_q;
-        lparams.low_vram = low_vram;
-        lparams.tensor_split = tensor_split.data();
+    llama_model_params to_llama_mparams() const {
+        llama_model_params mparams = llama_model_default_params();
+
+        mparams.n_gpu_layers = n_gpu_layers;
+        mparams.main_gpu = main_gpu;
+        mparams.tensor_split = tensor_split.data();
+
+        return mparams;
+    }
+
+    bool equal_mparams(const cmd_params_instance & other) const {
+        return model == other.model &&
+               n_gpu_layers == other.n_gpu_layers &&
+               main_gpu == other.main_gpu &&
+               tensor_split == other.tensor_split;
+    }
+
+    llama_context_params to_llama_cparams() const {
+        llama_context_params cparams = llama_context_default_params();
 
-        return lparams;
+        cparams.n_ctx = n_prompt + n_gen;
+        cparams.n_batch = n_batch;
+        cparams.f16_kv = !f32_kv;
+        cparams.mul_mat_q = mul_mat_q;
+
+        return cparams;
     }
 };
 
@@ -375,13 +377,12 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
     std::vector<cmd_params_instance> instances;
 
     for (const auto & m : params.model)
-    for (const auto & nb : params.n_batch)
-    for (const auto & fk : params.f32_kv)
     for (const auto & nl : params.n_gpu_layers)
     for (const auto & mg : params.main_gpu)
-    for (const auto & mmq : params.mul_mat_q)
-    for (const auto & lv : params.low_vram)
     for (const auto & ts : params.tensor_split)
+    for (const auto & nb : params.n_batch)
+    for (const auto & fk : params.f32_kv)
+    for (const auto & mmq : params.mul_mat_q)
     for (const auto & nt : params.n_threads) {
         cmd_params_instance instance = {
             /* .model        = */ m,
@@ -393,7 +394,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
             /* .n_gpu_layers = */ nl,
             /* .main_gpu     = */ mg,
             /* .mul_mat_q    = */ mmq,
-            /* .low_vram     = */ lv,
             /* .tensor_split = */ ts,
         };
         instances.push_back(instance);
@@ -404,6 +404,56 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
 static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_params & params) {
     std::vector<cmd_params_instance> instances;
 
+#if 1
+    // this ordering minimizes the number of times that each model needs to be reloaded
+    for (const auto & m : params.model)
+    for (const auto & nl : params.n_gpu_layers)
+    for (const auto & mg : params.main_gpu)
+    for (const auto & ts : params.tensor_split)
+    for (const auto & nb : params.n_batch)
+    for (const auto & fk : params.f32_kv)
+    for (const auto & mmq : params.mul_mat_q)
+    for (const auto & nt : params.n_threads) {
+        for (const auto & n_prompt : params.n_prompt) {
+            if (n_prompt == 0) {
+                continue;
+            }
+            cmd_params_instance instance = {
+                /* .model        = */ m,
+                /* .n_prompt     = */ n_prompt,
+                /* .n_gen        = */ 0,
+                /* .n_batch      = */ nb,
+                /* .f32_kv       = */ fk,
+                /* .n_threads    = */ nt,
+                /* .n_gpu_layers = */ nl,
+                /* .main_gpu     = */ mg,
+                /* .mul_mat_q    = */ mmq,
+                /* .tensor_split = */ ts,
+            };
+            instances.push_back(instance);
+        }
+
+        for (const auto & n_gen : params.n_gen) {
+            if (n_gen == 0) {
+                continue;
+            }
+            cmd_params_instance instance = {
+                /* .model        = */ m,
+                /* .n_prompt     = */ 0,
+                /* .n_gen        = */ n_gen,
+                /* .n_batch      = */ nb,
+                /* .f32_kv       = */ fk,
+                /* .n_threads    = */ nt,
+                /* .n_gpu_layers = */ nl,
+                /* .main_gpu     = */ mg,
+                /* .mul_mat_q    = */ mmq,
+                /* .tensor_split = */ ts,
+            };
+            instances.push_back(instance);
+        }
+    }
+#else
+    // this ordering separates the prompt and generation tests
     for (const auto & n_prompt : params.n_prompt) {
         if (n_prompt == 0) {
             continue;
@@ -419,6 +469,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
         auto instances_gen = get_cmd_params_instances_int(params, n_gen, 0);
         instances.insert(instances.end(), instances_gen.begin(), instances_gen.end());
     }
+#endif
 
     return instances;
 }
@@ -443,7 +494,6 @@ struct test {
     int n_gpu_layers;
     int main_gpu;
     bool mul_mat_q;
-    bool low_vram;
     std::array<float, LLAMA_MAX_DEVICES> tensor_split;
     int n_prompt;
     int n_gen;
@@ -463,7 +513,6 @@ struct test {
         n_gpu_layers = inst.n_gpu_layers;
         main_gpu = inst.main_gpu;
         mul_mat_q = inst.mul_mat_q;
-        low_vram = inst.low_vram;
         tensor_split = inst.tensor_split;
         n_prompt = inst.n_prompt;
         n_gen = inst.n_gen;
@@ -524,7 +573,7 @@ struct test {
             "cpu_info", "gpu_info",
             "model_filename", "model_type", "model_size", "model_n_params",
             "n_batch", "n_threads", "f16_kv",
-            "n_gpu_layers", "main_gpu", "mul_mat_q", "low_vram", "tensor_split",
+            "n_gpu_layers", "main_gpu", "mul_mat_q", "tensor_split",
             "n_prompt", "n_gen", "test_time",
             "avg_ns", "stddev_ns",
             "avg_ts", "stddev_ts"
@@ -543,7 +592,7 @@ struct test {
             return INT;
         }
         if (field == "cuda" || field == "opencl" || field == "metal" || field == "gpu_blas" || field == "blas" ||
-            field == "f16_kv" || field == "mul_mat_q" || field == "low_vram") {
+            field == "f16_kv" || field == "mul_mat_q") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {
@@ -574,7 +623,7 @@ struct test {
             cpu_info, gpu_info,
             model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
             std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv),
-            std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), std::to_string(low_vram), tensor_split_str,
+            std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), tensor_split_str,
             std::to_string(n_prompt), std::to_string(n_gen), test_time,
             std::to_string(avg_ns()), std::to_string(stdev_ns()),
             std::to_string(avg_ts()), std::to_string(stdev_ts())
@@ -766,9 +815,6 @@ struct markdown_printer : public printer {
         if (params.mul_mat_q.size() > 1 || params.mul_mat_q != cmd_params_defaults.mul_mat_q) {
             fields.push_back("mul_mat_q");
         }
-        if (params.low_vram.size() > 1 || params.low_vram != cmd_params_defaults.low_vram) {
-            fields.push_back("low_vram");
-        }
         if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
             fields.push_back("tensor_split");
         }
@@ -889,17 +935,23 @@ struct sql_printer : public printer {
 static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
     std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
     int n_processed = 0;
+
+    llama_set_n_threads(ctx, n_threads, n_threads);
+
     while (n_processed < n_prompt) {
         int n_tokens = std::min(n_prompt - n_processed, n_batch);
-        llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads);
+        llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
         n_processed += n_tokens;
     }
 }
 
 static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
     llama_token token = llama_token_bos(ctx);
+
+    llama_set_n_threads(ctx, n_threads, n_threads);
+
     for (int i = 0; i < n_gen; i++) {
-        llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads);
+        llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
     }
 }
 
@@ -958,17 +1010,25 @@ int main(int argc, char ** argv) {
 
     std::vector<cmd_params_instance> params_instances = get_cmd_params_instances(params);
 
+    llama_model * lmodel = nullptr;
+    const cmd_params_instance * prev_inst = nullptr;
+
     for (const auto & inst : params_instances) {
-        // TODO: keep the model between tests when possible
-        llama_context_params lparams = inst.to_llama_params();
+        // keep the same model between tests when possible
+        if (!lmodel || !prev_inst || !inst.equal_mparams(*prev_inst)) {
+            if (lmodel) {
+                llama_free_model(lmodel);
+            }
 
-        llama_model * lmodel  = llama_load_model_from_file(inst.model.c_str(), lparams);
-        if (lmodel == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str());
-            return 1;
+            lmodel = llama_load_model_from_file(inst.model.c_str(), inst.to_llama_mparams());
+            if (lmodel == NULL) {
+                fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str());
+                return 1;
+            }
+            prev_inst = &inst;
         }
 
-        llama_context * ctx = llama_new_context_with_model(lmodel, lparams);
+        llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams());
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str());
             llama_free_model(lmodel);
@@ -1006,9 +1066,10 @@ int main(int argc, char ** argv) {
         llama_print_timings(ctx);
 
         llama_free(ctx);
-        llama_free_model(lmodel);
     }
 
+    llama_free_model(lmodel);
+
     p->print_footer();
 
     llama_backend_free();
index 26e1e28dd08c179a8726ee8eff8b8d61ac56af8d..a9561c383c0cba7873808626cc4114e25dc1865d 100644 (file)
@@ -262,7 +262,8 @@ These options help improve the performance and memory usage of the LLaMA models.
 
 ### Number of Threads
 
--   `-t N, --threads N`: Set the number of threads to use during computation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance.
+-   `-t N, --threads N`: Set the number of threads to use during generation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance.
+-   `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. In some systems, it is beneficial to use a higher number of threads during batch processing than during generation. If not specified, the number of threads used for batch processing will be the same as the number of threads used for generation.
 
 ### Mlock
 
@@ -305,6 +306,5 @@ These options provide extra functionality and customization when running the LLa
 -   `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
 -   `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
 -   `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
--   `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS.
 -   `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
 -   `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
index 1ed543cbc627aadefe97db588cd7423688fe35a8..fd506773f74a345c93d9b5d1dde87a46e08f5951 100644 (file)
@@ -140,12 +140,17 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
-    if (params.rope_freq_base != 10000.0) {
-        LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+    if (params.n_ctx != 0 && params.n_ctx < 8) {
+        LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
+        params.n_ctx = 8;
+    }
+
+    if (params.rope_freq_base != 0.0) {
+        LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base);
     }
 
-    if (params.rope_freq_scale != 1.0) {
-        LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+    if (params.rope_freq_scale != 0.0) {
+        LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
     }
 
     LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
@@ -184,20 +189,19 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const int n_ctx_train = llama_n_ctx_train(ctx);
-    if (params.n_ctx > n_ctx_train) {
+    const int n_ctx_train = llama_n_ctx_train(model);
+    const int n_ctx = llama_n_ctx(ctx);
+    LOG("n_ctx: %d\n", n_ctx);
+
+    if (n_ctx > n_ctx_train) {
         LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
-                __func__, n_ctx_train, params.n_ctx);
-    } else if (params.n_ctx < 8) {
-        LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
-        params.n_ctx = 8;
+                __func__, n_ctx_train, n_ctx);
     }
 
     // print system information
     {
         LOG_TEE("\n");
-        LOG_TEE("system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+        LOG_TEE("%s\n", get_system_info(params).c_str());
     }
 
     std::string path_session = params.path_prompt_cache;
@@ -211,7 +215,7 @@ int main(int argc, char ** argv) {
         if (fp != NULL) {
             std::fclose(fp);
 
-            session_tokens.resize(params.n_ctx);
+            session_tokens.resize(n_ctx);
             size_t n_token_count_out = 0;
             if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
                 LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
@@ -226,7 +230,7 @@ int main(int argc, char ** argv) {
         }
     }
 
-    const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
@@ -267,9 +271,6 @@ int main(int argc, char ** argv) {
         LOG("guidance_offset:     %s", log_tostr(guidance_offset));
     }
 
-    const int n_ctx = llama_n_ctx(ctx);
-    LOG("n_ctx: %d\n", n_ctx);
-
     if ((int) embd_inp.size() > n_ctx - 4) {
         LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
         return 1;
@@ -466,7 +467,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
 
-    const int n_vocab = llama_n_vocab(ctx);
+    const int n_vocab = llama_n_vocab(model);
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -576,7 +577,7 @@ int main(int argc, char ** argv) {
 
                 for (int i = 0; i < input_size; i += params.n_batch) {
                     int n_eval = std::min(input_size - i, params.n_batch);
-                    if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) {
+                    if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
                         LOG_TEE("%s : failed to eval\n", __func__);
                         return 1;
                     }
@@ -593,7 +594,7 @@ int main(int argc, char ** argv) {
 
                 LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
 
-                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) {
+                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
                     LOG_TEE("%s : failed to eval\n", __func__);
                     return 1;
                 }
index 790189af988768abece332435f114a4e68e6048f..0434ded234b183cbd22fb23089e403664dc91518 100644 (file)
@@ -108,7 +108,7 @@ int main(int argc, char ** argv) {
     fflush(stderr);
 
     const int n_ctx   = llama_n_ctx(ctx);
-    const int n_vocab = llama_n_vocab(ctx);
+    const int n_vocab = llama_n_vocab(model);
 
     std::vector<client> clients(n_clients);
     for (size_t i = 0; i < clients.size(); ++i) {
@@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
             batch.logits[i] = false;
         }
 
-        if (llama_decode(ctx, batch, params.n_threads) != 0) {
+        if (llama_decode(ctx, batch) != 0) {
             LOG_TEE("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -272,7 +272,7 @@ int main(int argc, char ** argv) {
                 0, 0, 0, // unused
             };
 
-            const int ret = llama_decode(ctx, batch_view, params.n_threads);
+            const int ret = llama_decode(ctx, batch_view);
             if (ret != 0) {
                 if (n_batch == 1 || ret < 0) {
                     // if you get here, it means the KV cache is full - try increasing it via the context size
index de08bd4a185b4e874a625936cc0321a2ef0d50cd..7d0038bd4075735d73bbf7b19b997fdb45dea940 100644 (file)
@@ -150,16 +150,18 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
 
-    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
     const bool add_bos = is_spm;
 
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
     std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
-    if (int(tokens.size()) < 2*params.n_ctx) {
-        fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
-                params.n_ctx);
+    const int n_ctx = llama_n_ctx(ctx);
+
+    if (int(tokens.size()) < 2*n_ctx) {
+        fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
+                n_ctx);
         fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
         return {std::move(tokens), 0., {}, {}};
     }
@@ -175,20 +177,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         return {tokens, -1, logit_history, prob_history};
     }
 
-    const int calc_chunk = params.n_ctx;
+    const int calc_chunk = n_ctx;
 
     fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
 
     if (int(tokens.size()) <= calc_chunk) {
         fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
-                tokens.size(), params.n_ctx, params.ppl_stride);
+                tokens.size(), n_ctx, params.ppl_stride);
         return {tokens, -1, logit_history, prob_history};
     }
 
     const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1)  / params.ppl_stride;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(ctx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;
 
     int count = 0;
@@ -215,7 +217,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
             const int batch_size  = std::min(end - batch_start, n_batch);
 
             //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
+            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
                 //fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }
@@ -250,7 +252,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         }
 
         //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
-        for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {
+        for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
 
             // Calculate probability of next token, given the previous ones.
             const std::vector<float> tok_logits(
@@ -287,8 +289,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
 
-    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
     const bool add_bos = is_spm;
+    const int n_ctx = llama_n_ctx(ctx);
 
     auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
@@ -298,9 +301,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     auto tim2 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
 
-    if (int(tokens.size()) < 2*params.n_ctx) {
-        fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
-                params.n_ctx);
+    if (int(tokens.size()) < 2*n_ctx) {
+        fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
+                n_ctx);
         fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
         return {std::move(tokens), 0., {}, {}};
     }
@@ -311,10 +314,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     std::vector<float> prob_history;
     prob_history.resize(tokens.size());
 
-    const int n_chunk_max = tokens.size() / params.n_ctx;
+    const int n_chunk_max = tokens.size() / n_ctx;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(ctx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;
 
     int count = 0;
@@ -326,10 +329,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
     for (int i = 0; i < n_chunk; ++i) {
-        const int start =     i * params.n_ctx;
-        const int end   = start + params.n_ctx;
+        const int start =     i * n_ctx;
+        const int end   = start + n_ctx;
 
-        const int num_batches = (params.n_ctx + n_batch - 1) / n_batch;
+        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
 
         std::vector<float> logits;
 
@@ -350,7 +353,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 tokens[batch_start] = llama_token_bos(ctx);
             }
 
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
+            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }
@@ -358,7 +361,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;
 
-            const auto batch_logits = llama_get_logits(ctx);
+            const auto batch_logits = llama_get_logits(ctx);
             logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
         }
 
@@ -387,10 +390,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         // Example, we have a context window of 512, we will compute perplexity for each of the
         // last 256 tokens.  Then, we split the input up into context window size chunks to
         // process the entire prompt.
-        const int first = params.n_ctx/2;
-        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first,
+        const int first = n_ctx/2;
+        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                        workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-        count += params.n_ctx - first - 1;
+        count += n_ctx - first - 1;
 
         // perplexity is e^(average negative log-likelihood)
         if (params.ppl_output_type == 0) {
@@ -399,7 +402,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             double av = nll/count;
             double av2 = nll2/count - av*av;
             if (av2 > 0) av2 = sqrt(av2/(count-1));
-            printf("%8d  %.4lf  %4lf  %4lf\n", i*params.n_ctx, std::exp(nll / count), av, av2);
+            printf("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
         }
         fflush(stdout);
     }
@@ -420,7 +423,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 }
 
 static std::vector<float> hellaswag_evaluate_tokens(
-    llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab, int n_thread
+    llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab
 ) {
     std::vector<float> result;
     result.reserve(tokens.size() * n_vocab);
@@ -428,7 +431,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
     for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
         size_t n_tokens = tokens.size() - i_chunk * n_batch;
         n_tokens = std::min(n_tokens, size_t(n_batch));
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) {
+        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return {};
         }
@@ -475,7 +478,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     size_t hs_task_count = prompt_lines.size()/6;
     fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
 
-    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
     fprintf(stderr, "================================= is_spm = %d\n", is_spm);
 
     // This is needed as usual for LLaMA models
@@ -530,7 +533,8 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     printf("\ntask\tacc_norm\n");
 
     double acc = 0.0f;
-    const int n_vocab = llama_n_vocab(ctx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+    const int n_ctx = llama_n_ctx(ctx);
 
     std::vector<std::vector<int>> ending_tokens(4);
 
@@ -558,7 +562,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         auto query_size = query_embd.size();
 
         // Stop if query wont fit the ctx window
-        if (query_size > (size_t)params.n_ctx) {
+        if (query_size > (size_t)n_ctx) {
             fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
             return;
         }
@@ -571,7 +575,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         // clear the KV cache
         llama_kv_cache_tokens_rm(ctx, -1, -1);
 
-        auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
+        auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
         if (logits.empty()) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return;
@@ -608,7 +612,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             query_size = query_embd.size();
 
             // Stop if query wont fit the ctx window
-            if (context_size + query_size > (size_t)params.n_ctx) {
+            if (context_size + query_size > (size_t)n_ctx) {
                 fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
                 return;
             }
@@ -620,7 +624,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             //}
 
             // Evaluate the query
-            logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
+            logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
             if (logits.empty()) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
@@ -716,7 +720,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const int n_ctx_train = llama_n_ctx_train(ctx);
+    const int n_ctx_train = llama_n_ctx_train(model);
     if (params.n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, params.n_ctx);
@@ -725,8 +729,7 @@ int main(int argc, char ** argv) {
     // print system information
     {
         fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+        fprintf(stderr, "%s\n", get_system_info(params).c_str());
     }
 
     struct results_perplexity results;
index 94edb94d958d06d14e18436423a977f04b86b49f..dd76b1ceef134d2cdafe01c9c458a6bd32ee2abb 100644 (file)
@@ -309,21 +309,22 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
 
     {
-        auto lparams = llama_context_default_params();
+        auto mparams = llama_model_default_params();
+        mparams.use_mlock  = false;
 
-        lparams.n_ctx      = 256;
-        lparams.seed       = 1;
-        lparams.f16_kv     = false;
-        lparams.use_mlock  = false;
-
-        model = llama_load_model_from_file(params.model.c_str(), lparams);
+        model = llama_load_model_from_file(params.model.c_str(), mparams);
 
         if (model == NULL) {
             fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
             return 1;
         }
 
-        ctx = llama_new_context_with_model(model, lparams);
+        auto cparams = llama_context_default_params();
+        cparams.n_ctx      = 256;
+        cparams.seed       = 1;
+        cparams.f16_kv     = false;
+
+        ctx = llama_new_context_with_model(model, cparams);
 
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
index 6e4d40b9e1d6d3ec9db190d20dc76994dcec1362..acc6dbdfd07d05f7a7462078daf8cbd54dd0965b 100644 (file)
@@ -23,23 +23,17 @@ int main(int argc, char ** argv) {
         params.n_predict = 16;
     }
 
-    auto lparams = llama_context_default_params();
-
-    lparams.n_ctx     = params.n_ctx;
-    lparams.seed      = params.seed;
-    lparams.f16_kv    = params.memory_f16;
-    lparams.use_mmap  = params.use_mmap;
-    lparams.use_mlock = params.use_mlock;
-
     auto n_past = 0;
     auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
 
     // init
-    auto * model = llama_load_model_from_file(params.model.c_str(), lparams);
+    llama_model * model;
+    llama_context * ctx;
+
+    std::tie(model, ctx) = llama_init_from_gpt_params( params );
     if (model == nullptr) {
         return 1;
     }
-    auto * ctx = llama_new_context_with_model(model, lparams);
     if (ctx == nullptr) {
         llama_free_model(model);
         return 1;
@@ -54,7 +48,7 @@ int main(int argc, char ** argv) {
     }
 
     // evaluate prompt
-    llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads);
+    llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0));
 
     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
     n_past += n_prompt_tokens;
@@ -79,7 +73,7 @@ int main(int argc, char ** argv) {
 
     for (auto i = 0; i < params.n_predict; i++) {
         auto * logits = llama_get_logits(ctx);
-        auto n_vocab = llama_n_vocab(ctx);
+        auto n_vocab = llama_n_vocab(model);
         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);
         for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -91,7 +85,7 @@ int main(int argc, char ** argv) {
         last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
-        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
+        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx);
             llama_free_model(model);
@@ -106,7 +100,7 @@ int main(int argc, char ** argv) {
     llama_free(ctx);
 
     // make new context
-    auto * ctx2 = llama_new_context_with_model(model, lparams);
+    auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
 
     // Load state (rng, logits, embedding and kv_cache) from file
     {
@@ -139,7 +133,7 @@ int main(int argc, char ** argv) {
     // second run
     for (auto i = 0; i < params.n_predict; i++) {
         auto * logits = llama_get_logits(ctx2);
-        auto n_vocab = llama_n_vocab(ctx2);
+        auto n_vocab = llama_n_vocab(model);
         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);
         for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -151,7 +145,7 @@ int main(int argc, char ** argv) {
         last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
-        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
+        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
index 5176080463839618087ff6dd90bfd6ef7d6ba907..d409e8408f192df6c5ce7f331a73103e3bdf06fe 100644 (file)
@@ -4,14 +4,14 @@ This example demonstrates a simple HTTP API server and a simple web front end to
 
 Command line options:
 
--   `--threads N`, `-t N`: Set the number of threads to use during computation.
+-   `--threads N`, `-t N`: Set the number of threads to use during generation.
+-   `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. If not specified, the number of threads will be set to the number of threads used for generation.
 -   `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`).
 -   `-m ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses.
 -   `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of 4096.
 -   `-ngl N`, `--n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
 -   `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
 -   `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
--   `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS.
 -   `-b N`, `--batch-size N`: Set the batch size for prompt processing. Default: `512`.
 -   `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. Not recommended.
 -   `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped.
index 9b96248324a6eb35f4878f979bf914e39695dc57..fe9a4255e768006e5abd036edbb529205e190cf5 100644 (file)
@@ -200,6 +200,7 @@ struct llama_server_context
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
     gpt_params params;
+    int n_ctx;
 
     grammar_parser::parse_state parsed_grammar;
     llama_grammar *grammar = nullptr;
@@ -239,7 +240,7 @@ struct llama_server_context
         num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
-        generated_text.reserve(params.n_ctx);
+        generated_text.reserve(n_ctx);
         generated_token_probs.clear();
         truncated = false;
         stopped_eos = false;
@@ -265,8 +266,8 @@ struct llama_server_context
             LOG_ERROR("unable to load model", {{"model", params_.model}});
             return false;
         }
-
-        last_n_tokens.resize(params.n_ctx);
+        n_ctx = llama_n_ctx(ctx);
+        last_n_tokens.resize(n_ctx);
         std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
         return true;
     }
@@ -351,19 +352,19 @@ struct llama_server_context
         {
             params.n_keep = (int)num_prompt_tokens;
         }
-        params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+        params.n_keep = std::min(n_ctx - 4, params.n_keep);
 
         // if input prompt is too big, truncate like normal
-        if (num_prompt_tokens >= (size_t)params.n_ctx)
+        if (num_prompt_tokens >= (size_t)n_ctx)
         {
-            const int n_left = (params.n_ctx - params.n_keep) / 2;
+            const int n_left = (n_ctx - params.n_keep) / 2;
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
             const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
             new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
 
             LOG_VERBOSE("input truncated", {
-                                               {"n_ctx", params.n_ctx},
+                                               {"n_ctx", n_ctx},
                                                {"n_keep", params.n_keep},
                                                {"n_left", n_left},
                                                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
@@ -413,7 +414,7 @@ struct llama_server_context
         completion_token_output result;
         result.tok = -1;
 
-        if (embd.size() >= (size_t)params.n_ctx)
+        if (embd.size() >= (size_t)n_ctx)
         {
             // Shift context
 
@@ -433,7 +434,7 @@ struct llama_server_context
 
             truncated = true;
             LOG_VERBOSE("input truncated", {
-                                               {"n_ctx", params.n_ctx},
+                                               {"n_ctx", n_ctx},
                                                {"n_keep", params.n_keep},
                                                {"n_left", n_left},
                                            });
@@ -447,12 +448,11 @@ struct llama_server_context
                 n_eval = params.n_batch;
             }
 
-            if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
+            if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0)))
             {
                 LOG_ERROR("failed to eval", {
                                                 {"n_eval", n_eval},
                                                 {"n_past", n_past},
-                                                {"n_threads", params.n_threads},
                                                 {"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
                                             });
                 has_next_token = false;
@@ -470,11 +470,11 @@ struct llama_server_context
 
         // out of user input, sample next token
         const float temp = params.temp;
-        const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+        const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(model) : params.top_k;
         const float top_p = params.top_p;
         const float tfs_z = params.tfs_z;
         const float typical_p = params.typical_p;
-        const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
+        const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
         const float repeat_penalty = params.repeat_penalty;
         const float alpha_presence = params.presence_penalty;
         const float alpha_frequency = params.frequency_penalty;
@@ -486,7 +486,7 @@ struct llama_server_context
 
         {
             auto *logits = llama_get_logits(ctx);
-            auto n_vocab = llama_n_vocab(ctx);
+            auto n_vocab = llama_n_vocab(model);
 
             // Apply params.logit_bias map
             for (const auto &it : params.logit_bias)
@@ -505,7 +505,7 @@ struct llama_server_context
 
             // Apply penalties
             float nl_logit = logits[llama_token_nl(ctx)];
-            auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
+            auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
             llama_sample_repetition_penalty(ctx, &candidates_p,
                                             last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                                             last_n_repeat, repeat_penalty);
@@ -690,7 +690,7 @@ struct llama_server_context
 
     std::vector<float> getEmbedding()
     {
-        static const int n_embd = llama_n_embd(ctx);
+        static const int n_embd = llama_n_embd(model);
         if (!params.embedding)
         {
             LOG_WARNING("embedding disabled", {
@@ -734,7 +734,6 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  -ts SPLIT --tensor-split SPLIT\n");
     printf("                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
     printf("  -mg i, --main-gpu i   the GPU to use for scratch and small tensors\n");
-    printf("  -lv, --low-vram       don't allocate VRAM scratch buffer\n");
     printf("  -nommq, --no-mul-mat-q\n");
     printf("                        use cuBLAS instead of custom mul_mat_q CUDA kernels.\n");
     printf("                        Not recommended since this is both slower and uses more VRAM.\n");
@@ -918,14 +917,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
 #else
             LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {});
-#endif // GGML_USE_CUBLAS
-        }
-        else if (arg == "--low-vram" || arg == "-lv")
-        {
-#ifdef GGML_USE_CUBLAS
-            params.low_vram = true;
-#else
-            LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n", {});
 #endif // GGML_USE_CUBLAS
         }
         else if (arg == "--no-mul-mat-q" || arg == "-nommq")
@@ -1031,7 +1022,7 @@ static json format_generation_settings(llama_server_context &llama)
                             eos_bias->second < 0.0f && std::isinf(eos_bias->second);
 
     return json{
-        {"n_ctx", llama.params.n_ctx},
+        {"n_ctx", llama.n_ctx},
         {"model", llama.params.model_alias},
         {"seed", llama.params.seed},
         {"temp", llama.params.temp},
@@ -1191,7 +1182,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
     const auto &logit_bias = body.find("logit_bias");
     if (logit_bias != body.end() && logit_bias->is_array())
     {
-        const int n_vocab = llama_n_vocab(llama.ctx);
+        const int n_vocab = llama_n_vocab(llama.model);
         for (const auto &el : *logit_bias)
         {
             if (el.is_array() && el.size() == 2 && el[0].is_number_integer())
@@ -1324,6 +1315,7 @@ int main(int argc, char **argv)
                             {"commit", BUILD_COMMIT}});
     LOG_INFO("system info", {
                                 {"n_threads", params.n_threads},
+                                {"n_threads_batch", params.n_threads_batch},
                                 {"total_threads", std::thread::hardware_concurrency()},
                                 {"system_info", llama_print_system_info()},
                             });
@@ -1387,7 +1379,7 @@ int main(int argc, char **argv)
             if (llama.params.n_beams) {
                 // Fill llama.generated_token_probs vector with final beam.
                 llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
-                                  llama.n_past, llama.n_remain, llama.params.n_threads);
+                                  llama.n_past, llama.n_remain);
                 // Translate llama.generated_token_probs to llama.generated_text.
                 append_to_generated_text_from_generated_token_probs(llama);
             } else {
index 1616a4a7581a30067d720b64bdc1884d64f9b8b2..24fb16b78d0581e92f6b3ec04fa77b2d44ba7208 100644 (file)
@@ -33,18 +33,28 @@ int main(int argc, char ** argv) {
 
     llama_backend_init(params.numa);
 
-    llama_context_params ctx_params = llama_context_default_params();
+    // initialize the model
 
-    ctx_params.seed  = 1234;
-    ctx_params.n_ctx = 2048;
+    llama_model_params model_params = llama_model_default_params();
 
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
+    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
 
     if (model == NULL) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
         return 1;
     }
 
+    // initialize the context
+
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.seed  = 1234;
+    ctx_params.n_ctx = 2048;
+    ctx_params.n_threads = params.n_threads;
+    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
@@ -97,7 +107,7 @@ int main(int argc, char ** argv) {
     // llama_decode will output logits only for the last token of the prompt
     batch.logits[batch.n_tokens - 1] = true;
 
-    if (llama_decode(ctx, batch, params.n_threads) != 0) {
+    if (llama_decode(ctx, batch) != 0) {
         LOG_TEE("%s: llama_decode() failed\n", __func__);
         return 1;
     }
@@ -112,7 +122,7 @@ int main(int argc, char ** argv) {
     while (n_cur <= n_len) {
         // sample the next token
         {
-            auto   n_vocab = llama_n_vocab(ctx);
+            auto   n_vocab = llama_n_vocab(model);
             auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
 
             std::vector<llama_token_data> candidates;
@@ -154,7 +164,7 @@ int main(int argc, char ** argv) {
         n_cur += 1;
 
         // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch, params.n_threads)) {
+        if (llama_decode(ctx, batch)) {
             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
index 2445d78dc978889fa83bc3dab44ebd3931246400..c5e5b234f0f5cc45037c041a69e81912e1c7ce4c 100644 (file)
@@ -70,16 +70,16 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt with both models
-    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0,           0), params.n_threads);
-    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(),           1, n_input - 1, 0), params.n_threads);
-    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input,     0,           0), params.n_threads);
+    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0,           0));
+    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(),           1, n_input - 1, 0));
+    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input,     0,           0));
 
     const auto t_enc_end = ggml_time_us();
 
     // the 2 models should have the same vocab
     const int n_ctx   = llama_n_ctx(ctx_tgt);
-    const int n_vocab = llama_n_vocab(ctx_tgt);
-    //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
+    const int n_vocab = llama_n_vocab(model_tgt);
+    //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
 
     // how many tokens to draft each time
     int n_draft = params.n_draft;
@@ -173,7 +173,7 @@ int main(int argc, char ** argv) {
             }
 
             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
-            llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
+            llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
             ++n_past_dft;
 
             // heuristic for n_draft
@@ -258,7 +258,7 @@ int main(int argc, char ** argv) {
 
             // evaluate the drafted token on the draft model
             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
-            llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
+            llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
             ++n_past_cur;
 
             if (grammar_dft != NULL) {
@@ -268,7 +268,7 @@ int main(int argc, char ** argv) {
 
         // evaluate the target model on the drafted tokens
         llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
-        llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
+        llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
         ++n_past_tgt;
 
         // the first token is always proposed by the traget model before the speculation loop
index d5205aff6dda51dfd9fb9bd81439651bc3043570..a9cf8a38139e36c79129de72c9758bf9e4f76d32 100644 (file)
@@ -976,14 +976,16 @@ int main(int argc, char ** argv) {
     printf("%s: seed: %u\n", __func__, params.common.seed);
     srand(params.common.seed);
 
-    struct llama_context_params llama_params = llama_context_default_params();
-    llama_params.vocab_only = true;
+    struct llama_model_params mparams = llama_model_default_params();
+    mparams.vocab_only = true;
 
-    struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params);
-    struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
+    struct llama_context_params cparams = llama_context_default_params();
+
+    struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, mparams);
+    struct llama_context * lctx = llama_new_context_with_model(lmodel, cparams);
 
     struct my_llama_model model;
-    model.hparams.n_vocab = llama_n_vocab(lctx);
+    model.hparams.n_vocab = llama_n_vocab(lmodel);
     model.hparams.n_ctx   = params.common.n_ctx;
     model.hparams.n_embd  = params.n_embd;
     model.hparams.n_head  = params.n_head;
index 29fb7abd4296aa1b225bf53c32afa5334bf2360c..86d1fe203a4653d7c5ea1697f4cf7de85ea4b240 100644 (file)
@@ -1,3 +1,4 @@
+#include <algorithm>
 #include <cstddef>
 #include <cstdint>
 #include <limits>
@@ -467,7 +468,7 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 static bool g_mul_mat_q = true;
 
 static void * g_scratch_buffer = nullptr;
-static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
+static size_t g_scratch_size = 0; // disabled by default
 static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@@ -6738,14 +6739,10 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     const int64_t ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
-        src1->type == GGML_TYPE_F32 &&
-        dst->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-        return true;
-    }
-
-    return false;
+    return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+            src1->type == GGML_TYPE_F32 &&
+             dst->type == GGML_TYPE_F32 &&
+            (ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
 }
 
 static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -6901,6 +6898,8 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                               ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
+        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
         GGML_ASSERT(false);
     }
 
@@ -7198,7 +7197,12 @@ void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
 }
 
 void ggml_cuda_set_scratch_size(const size_t scratch_size) {
-    g_scratch_size = scratch_size;
+    // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously
+    // it still won't always work as expected, but it's better than nothing
+    if (scratch_size > g_scratch_size) {
+        ggml_cuda_free_scratch();
+    }
+    g_scratch_size = std::max(g_scratch_size, scratch_size);
 }
 
 void ggml_cuda_free_scratch() {
index 7668cb1a7edbbddc3c26551234164a1bdc3ebc41..685712d172666d42e6200b166e5cd1d7836c3f76 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -887,10 +887,10 @@ static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
 
 static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
+    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_piece(ctx, token, result.data(), result.size());
+        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -931,9 +931,9 @@ static const size_t MB = kB*kB;
 static const size_t GB = kB*kB*kB;
 
 struct llama_hparams {
+    bool     vocab_only;
     uint32_t n_vocab;
     uint32_t n_ctx_train; // context size the model was trained on
-    uint32_t n_ctx;       // context size used during inference
     uint32_t n_embd;
     uint32_t n_head;
     uint32_t n_head_kv;
@@ -944,8 +944,8 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;
 
-    float rope_freq_base;
-    float rope_freq_scale;
+    float rope_freq_base_train;
+    float rope_freq_scale_train;
 
     bool operator!=(const llama_hparams & other) const {
         return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
@@ -962,15 +962,18 @@ struct llama_hparams {
     uint32_t n_embd_gqa() const {
         return n_embd/n_gqa();
     }
+};
 
-    size_t kv_size() const {
-        size_t result = 2ull;
-        result *= (size_t) n_embd_gqa();
-        result *= (size_t) n_ctx;
-        result *= (size_t) n_layer;
-        result *= sizeof(ggml_fp16_t);
-        return result;
-    }
+struct llama_cparams {
+    uint32_t n_ctx;       // context size used during inference
+    uint32_t n_batch;
+    uint32_t n_threads;       // number of threads to use for generation
+    uint32_t n_threads_batch; // number of threads to use for batch processing
+
+    float rope_freq_base;
+    float rope_freq_scale;
+
+    bool mul_mat_q;
 };
 
 struct llama_layer {
@@ -1148,11 +1151,8 @@ struct llama_model {
 };
 
 struct llama_context {
-    llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
+    llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
     ~llama_context() {
-        if (model_owner) {
-            delete &model;
-        }
 #ifdef GGML_USE_METAL
         if (ctx_metal) {
             ggml_metal_free(ctx_metal);
@@ -1163,27 +1163,26 @@ struct llama_context {
         }
     }
 
+    llama_cparams cparams;
+
+    const llama_model & model;
+
+    // key + value cache for the self attention
+    struct llama_kv_cache kv_self;
+
     std::mt19937 rng;
 
     bool has_evaluated_once = false;
 
+    int64_t t_start_us;
+    int64_t t_load_us;
     int64_t t_sample_us = 0;
-    int64_t t_eval_us   = 0;
     int64_t t_p_eval_us = 0;
+    int64_t t_eval_us   = 0;
 
     int32_t n_sample = 0; // number of tokens sampled
-    int32_t n_eval   = 0; // number of eval calls
     int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
-
-    const llama_model & model;
-
-    bool model_owner = false;
-
-    int64_t t_load_us;
-    int64_t t_start_us;
-
-    // key + value cache for the self attention
-    struct llama_kv_cache kv_self;
+    int32_t n_eval   = 0; // number of eval calls
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -1218,10 +1217,10 @@ static bool llama_kv_cache_init(
         const struct llama_hparams & hparams,
              struct llama_kv_cache & cache,
                          ggml_type   wtype,
+                          uint32_t   n_ctx,
                                int   n_gpu_layers) {
     const uint32_t n_embd  = hparams.n_embd_gqa();
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_ctx   = hparams.n_ctx;
 
     const int64_t n_mem      = n_layer*n_ctx;
     const int64_t n_elements = n_embd*n_mem;
@@ -1255,11 +1254,20 @@ static bool llama_kv_cache_init(
 
     (void) n_gpu_layers;
 #ifdef GGML_USE_CUBLAS
+    size_t vram_kv_cache = 0;
+
     if (n_gpu_layers > (int)n_layer + 1) {
         ggml_cuda_assign_buffers_no_scratch(cache.v);
+        LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
+        vram_kv_cache += ggml_nbytes(cache.v);
     }
     if (n_gpu_layers > (int)n_layer + 2) {
         ggml_cuda_assign_buffers_no_scratch(cache.k);
+        LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
+        vram_kv_cache += ggml_nbytes(cache.k);
+    }
+    if (vram_kv_cache > 0) {
+        LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
     }
 #endif // GGML_USE_CUBLAS
 
@@ -1715,7 +1723,7 @@ struct llama_model_loader {
                         lmlock->grow_to(size_lock);
                     }
                     break;
-#if defined(GGML_USE_CUBLAS)
+#ifdef GGML_USE_CUBLAS
                 case GGML_BACKEND_GPU:
                 case GGML_BACKEND_GPU_SPLIT:
                     // old code:
@@ -1748,7 +1756,15 @@ struct llama_model_loader {
 // load LLaMA models
 //
 
-static std::string llama_model_ftype_name(enum llama_ftype ftype) {
+static std::string llama_model_arch_name(llm_arch arch) {
+    auto it = LLM_ARCH_NAMES.find(arch);
+    if (it == LLM_ARCH_NAMES.end()) {
+        return "unknown";
+    }
+    return it->second;
+}
+
+static std::string llama_model_ftype_name(llama_ftype ftype) {
     if (ftype & LLAMA_FTYPE_GUESSED) {
         return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
     }
@@ -1804,10 +1820,7 @@ static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
 
 static void llm_load_hparams(
         llama_model_loader & ml,
-        llama_model & model,
-        int n_ctx,
-        float rope_freq_base,
-        float rope_freq_scale) {
+        llama_model & model) {
     struct gguf_context * ctx = ml.ctx_gguf;
 
     const auto kv = LLM_KV(model.arch);
@@ -1818,29 +1831,25 @@ static void llm_load_hparams(
     GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
 
     // get hparams kv
-    GGUF_GET_KEY(ctx, hparams.n_vocab,        gguf_get_arr_n,   GGUF_TYPE_ARRAY,   true, kv(LLM_KV_TOKENIZER_LIST));
-    GGUF_GET_KEY(ctx, hparams.n_ctx_train,    gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_CONTEXT_LENGTH));
-    GGUF_GET_KEY(ctx, hparams.n_embd,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_EMBEDDING_LENGTH));
-    GGUF_GET_KEY(ctx, hparams.n_ff,           gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_FEED_FORWARD_LENGTH));
-    GGUF_GET_KEY(ctx, hparams.n_head,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
-    GGUF_GET_KEY(ctx, hparams.n_layer,        gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_BLOCK_COUNT));
+    GGUF_GET_KEY(ctx, hparams.n_vocab,        gguf_get_arr_n,   GGUF_TYPE_ARRAY,  true, kv(LLM_KV_TOKENIZER_LIST));
+    GGUF_GET_KEY(ctx, hparams.n_ctx_train,    gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH));
+    GGUF_GET_KEY(ctx, hparams.n_embd,         gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
+    GGUF_GET_KEY(ctx, hparams.n_ff,           gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
+    GGUF_GET_KEY(ctx, hparams.n_head,         gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
+    GGUF_GET_KEY(ctx, hparams.n_layer,        gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv = hparams.n_head;
     GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
 
     // rope_freq_base (optional)
-    if (rope_freq_base == 0.0f) {
-        rope_freq_base = 10000.0f;
-        GGUF_GET_KEY(ctx, rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
-    }
+    hparams.rope_freq_base_train = 10000.0f;
+    GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
 
     // rope_freq_scale (inverse of the kv) is optional
-    if (rope_freq_scale == 0.0f) {
-        float ropescale = 1.0f;
-        GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
-        rope_freq_scale = 1.0f/ropescale;
-    }
+    float ropescale = 1.0f;
+    GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    hparams.rope_freq_scale_train = 1.0f/ropescale;
 
     // sanity check for n_rot (optional)
     {
@@ -1907,10 +1916,6 @@ static void llm_load_hparams(
     };
 
     model.ftype = ml.ftype;
-
-    hparams.n_ctx           = n_ctx;
-    hparams.rope_freq_base  = rope_freq_base;
-    hparams.rope_freq_scale = rope_freq_scale;
 }
 
 // TODO: This should probably be in llama.h
@@ -2034,31 +2039,30 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     const auto & vocab   = model.vocab;
 
     // hparams
-    LLAMA_LOG_INFO("%s: format         = %s\n",     __func__, llama_file_version_name(ml.fver));
-    LLAMA_LOG_INFO("%s: arch           = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
-    LLAMA_LOG_INFO("%s: vocab type     = %s\n",     __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
-    LLAMA_LOG_INFO("%s: n_vocab        = %u\n",     __func__, hparams.n_vocab);
-    LLAMA_LOG_INFO("%s: n_merges       = %u\n",     __func__, (int) vocab.bpe_ranks.size());
-    LLAMA_LOG_INFO("%s: n_ctx_train    = %u\n",     __func__, hparams.n_ctx_train);
-    LLAMA_LOG_INFO("%s: n_ctx          = %u\n",     __func__, hparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_embd         = %u\n",     __func__, hparams.n_embd);
-    LLAMA_LOG_INFO("%s: n_head         = %u\n",     __func__, hparams.n_head);
-    LLAMA_LOG_INFO("%s: n_head_kv      = %u\n",     __func__, hparams.n_head_kv);
-    LLAMA_LOG_INFO("%s: n_layer        = %u\n",     __func__, hparams.n_layer);
-    LLAMA_LOG_INFO("%s: n_rot          = %u\n",     __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
-    LLAMA_LOG_INFO("%s: n_gqa          = %u\n",     __func__, hparams.n_gqa());
-    LLAMA_LOG_INFO("%s: f_norm_eps     = %.1e\n",   __func__, hparams.f_norm_eps);
-    LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n",   __func__, hparams.f_norm_rms_eps);
-    LLAMA_LOG_INFO("%s: n_ff           = %u\n",     __func__, hparams.n_ff);
-    LLAMA_LOG_INFO("%s: freq_base      = %.1f\n",   __func__, hparams.rope_freq_base);
-    LLAMA_LOG_INFO("%s: freq_scale     = %g\n",     __func__, hparams.rope_freq_scale);
-    LLAMA_LOG_INFO("%s: model type     = %s\n",     __func__, llama_model_type_name(model.type));
-    LLAMA_LOG_INFO("%s: model ftype    = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
-    LLAMA_LOG_INFO("%s: model params   = %.2f B\n", __func__, ml.n_elements*1e-9);
+    LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
+    LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
+    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
+    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
+    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
+    LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
+    LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
+    LLAMA_LOG_INFO("%s: n_head           = %u\n",     __func__, hparams.n_head);
+    LLAMA_LOG_INFO("%s: n_head_kv        = %u\n",     __func__, hparams.n_head_kv);
+    LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
+    LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+    LLAMA_LOG_INFO("%s: n_gqa            = %u\n",     __func__, hparams.n_gqa());
+    LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
+    LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
+    LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
+    LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
+    LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
+    LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
+    LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
+    LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, ml.n_elements*1e-9);
     if (ml.n_bytes < GB) {
-        LLAMA_LOG_INFO("%s: model size     = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
+        LLAMA_LOG_INFO("%s: model size       = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
     } else {
-        LLAMA_LOG_INFO("%s: model size     = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
+        LLAMA_LOG_INFO("%s: model size       = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
     }
 
     // general kv
@@ -2076,13 +2080,9 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
 static void llm_load_tensors(
         llama_model_loader & ml,
         llama_model & model,
-        int n_batch,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
-        const bool mul_mat_q,
-        bool low_vram,
-        ggml_type memory_type,
         bool use_mlock,
         llama_progress_callback progress_callback,
         void * progress_callback_user_data) {
@@ -2121,11 +2121,9 @@ static void llm_load_tensors(
     }
 
     (void) main_gpu;
-    (void) mul_mat_q;
-#if defined(GGML_USE_CUBLAS)
+#ifdef GGML_USE_CUBLAS
     LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
     ggml_cuda_set_main_device(main_gpu);
-    ggml_cuda_set_mul_mat_q(mul_mat_q);
 #define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_GPU
 #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
 #elif defined(GGML_USE_CLBLAST)
@@ -2160,9 +2158,9 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = LLAMA_BACKEND_OFFLOAD;
 #else
-                            backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
 #endif // _WIN32
 
                             backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
@@ -2226,9 +2224,9 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = LLAMA_BACKEND_OFFLOAD;
 #else
-                            backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
 #endif // _WIN32
 
                             backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
@@ -2296,9 +2294,9 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = LLAMA_BACKEND_OFFLOAD;
 #else
-                            backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
 #endif // _WIN32
 
                             backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
@@ -2373,9 +2371,9 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = LLAMA_BACKEND_OFFLOAD;
 #else
-                            backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
 #endif // _WIN32
 
                             backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
@@ -2447,20 +2445,12 @@ static void llm_load_tensors(
 
     // print memory requirements
     {
-        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
-
         // this is the total memory required to run the inference
         size_t mem_required =
             ctx_size +
             mmapped_size - vram_weights; // weights in VRAM not in memory
 
-        // this is the memory required by one llama_state
-        const size_t mem_required_state = scale*hparams.kv_size();
-
-        LLAMA_LOG_INFO("%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
-                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
-
-        (void) n_batch;
+        LLAMA_LOG_INFO("%s: mem required  = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
 
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
         const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
@@ -2469,36 +2459,17 @@ static void llm_load_tensors(
         if (n_gpu_layers > (int) hparams.n_layer) {
             LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__);
         }
-        size_t vram_kv_cache = 0;
 
 #ifdef GGML_USE_CUBLAS
         const int max_backend_supported_layers = hparams.n_layer + 3;
-        const int max_offloadable_layers = low_vram ? hparams.n_layer + 1 : hparams.n_layer + 3;
-        if (n_gpu_layers > (int) hparams.n_layer + 1) {
-            if (low_vram) {
-                LLAMA_LOG_INFO("%s: cannot offload v cache to GPU due to low VRAM option\n", __func__);
-            } else {
-                LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
-                vram_kv_cache += hparams.kv_size() / 2;
-            }
-        }
-        if (n_gpu_layers > (int) hparams.n_layer + 2) {
-            if (low_vram) {
-                LLAMA_LOG_WARN("%s: cannot offload k cache to GPU due to low VRAM option\n", __func__);
-            } else {
-                LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
-                vram_kv_cache += hparams.kv_size() / 2;
-            }
-        }
+        const int max_offloadable_layers = hparams.n_layer + 3;
 #elif defined(GGML_USE_CLBLAST)
         const int max_backend_supported_layers = hparams.n_layer + 1;
         const int max_offloadable_layers = hparams.n_layer + 1;
 #endif // GGML_USE_CUBLAS
 
-        LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
-                __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
-        LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n",
-                __func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up
+        LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
+        LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_weights / 1024.0 / 1024.0);
 #else
         (void) n_gpu_layers;
 #endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
@@ -2511,7 +2482,7 @@ static void llm_load_tensors(
     }
 
     (void) tensor_split;
-#if defined(GGML_USE_CUBLAS)
+#ifdef GGML_USE_CUBLAS
     {
         ggml_cuda_set_tensor_split(tensor_split);
     }
@@ -2533,29 +2504,24 @@ static void llm_load_tensors(
 static bool llama_model_load(
         const std::string & fname,
         llama_model & model,
-        int n_ctx,
-        int n_batch,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
-        const bool mul_mat_q,
-        float rope_freq_base,
-        float rope_freq_scale,
-        bool low_vram,
-        ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
         bool vocab_only,
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname, use_mmap));
+        llama_model_loader ml(fname, use_mmap);
 
-        llm_load_arch   (*ml, model);
-        llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale);
-        llm_load_vocab  (*ml, model);
+        model.hparams.vocab_only = vocab_only;
 
-        llm_load_print_meta(*ml, model);
+        llm_load_arch   (ml, model);
+        llm_load_hparams(ml, model);
+        llm_load_vocab  (ml, model);
+
+        llm_load_print_meta(ml, model);
 
         if (model.hparams.n_vocab != model.vocab.id_to_token.size()) {
             throw std::runtime_error("vocab size mismatch");
@@ -2567,8 +2533,8 @@ static bool llama_model_load(
         }
 
         llm_load_tensors(
-                *ml, model, n_batch, n_gpu_layers,
-                main_gpu, tensor_split, mul_mat_q, low_vram, memory_type,
+                ml, model, n_gpu_layers,
+                main_gpu, tensor_split,
                 use_mlock, progress_callback, progress_callback_user_data);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
@@ -2583,6 +2549,7 @@ static struct ggml_cgraph * llm_build_llama(
      const llama_batch & batch) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
 
     const auto & kv_self = lctx.kv_self;
 
@@ -2590,7 +2557,7 @@ static struct ggml_cgraph * llm_build_llama(
 
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = hparams.n_ctx;
+    const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head      = hparams.n_head;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
@@ -2598,8 +2565,8 @@ static struct ggml_cgraph * llm_build_llama(
 
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-    const float freq_base    = hparams.rope_freq_base;
-    const float freq_scale   = hparams.rope_freq_scale;
+    const float freq_base    = cparams.rope_freq_base;
+    const float freq_scale   = cparams.rope_freq_scale;
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
     const int n_gpu_layers = model.n_gpu_layers;
@@ -2657,9 +2624,6 @@ static struct ggml_cgraph * llm_build_llama(
 
     // offload functions set the tensor output backend to GPU
     // tensors are GPU-accelerated if any input or the output has been offloaded
-    //
-    // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
-    // in that case ggml_cuda_assign_buffers has no effect
     offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
     offload_func_t offload_func_kq = llama_nop;
     offload_func_t offload_func_v  = llama_nop;
@@ -2975,6 +2939,7 @@ static struct ggml_cgraph * llm_build_baichaun(
      const llama_batch & batch) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
 
     const auto & kv_self = lctx.kv_self;
 
@@ -2982,7 +2947,7 @@ static struct ggml_cgraph * llm_build_baichaun(
 
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = hparams.n_ctx;
+    const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head      = hparams.n_head;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
@@ -2990,8 +2955,8 @@ static struct ggml_cgraph * llm_build_baichaun(
 
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-    const float freq_base    = hparams.rope_freq_base;
-    const float freq_scale   = hparams.rope_freq_scale;
+    const float freq_base    = cparams.rope_freq_base;
+    const float freq_scale   = cparams.rope_freq_scale;
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
     const int n_gpu_layers = model.n_gpu_layers;
@@ -3047,9 +3012,6 @@ static struct ggml_cgraph * llm_build_baichaun(
 
     // offload functions set the tensor output backend to GPU
     // tensors are GPU-accelerated if any input or the output has been offloaded
-    //
-    // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
-    // in that case ggml_cuda_assign_buffers has no effect
     offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
     offload_func_t offload_func_kq = llama_nop;
     offload_func_t offload_func_v  = llama_nop;
@@ -3382,6 +3344,7 @@ static struct ggml_cgraph * llm_build_falcon(
      const llama_batch & batch) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
 
     const auto & kv_self = lctx.kv_self;
 
@@ -3389,7 +3352,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = hparams.n_ctx;
+    const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head      = hparams.n_head;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
@@ -3397,8 +3360,8 @@ static struct ggml_cgraph * llm_build_falcon(
 
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-    const float freq_base  = hparams.rope_freq_base;
-    const float freq_scale = hparams.rope_freq_scale;
+    const float freq_base  = cparams.rope_freq_base;
+    const float freq_scale = cparams.rope_freq_scale;
     const float norm_eps   = hparams.f_norm_eps;
 
     const int n_gpu_layers = model.n_gpu_layers;
@@ -3457,9 +3420,6 @@ static struct ggml_cgraph * llm_build_falcon(
 
     // offload functions set the tensor output backend to GPU
     // tensors are GPU-accelerated if any input or the output has been offloaded
-    //
-    // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
-    // in that case ggml_cuda_assign_buffers has no effect
     offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
     offload_func_t offload_func_kq = llama_nop;
     offload_func_t offload_func_v  = llama_nop;
@@ -3753,6 +3713,7 @@ static struct ggml_cgraph * llm_build_starcoder(
      const llama_batch & batch) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
 
     const auto & kv_self = lctx.kv_self;
 
@@ -3760,7 +3721,7 @@ static struct ggml_cgraph * llm_build_starcoder(
 
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = hparams.n_ctx;
+    const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head      = hparams.n_head;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
@@ -4037,8 +3998,7 @@ static struct ggml_cgraph * llama_build_graph(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch,
-                   int   n_threads) {
+           llama_batch   batch) {
     const uint32_t n_tokens = batch.n_tokens;
 
     if (n_tokens == 0) {
@@ -4046,6 +4006,15 @@ static int llama_decode_internal(
         return -1;
     }
 
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    const auto n_batch = cparams.n_batch;
+
+    GGML_ASSERT(n_tokens <= n_batch);
+
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     const int64_t t_start_us = ggml_time_us();
@@ -4058,9 +4027,6 @@ static int llama_decode_internal(
 
     GGML_ASSERT(n_threads > 0);
 
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-
     auto & kv_self = lctx.kv_self;
 
     GGML_ASSERT(!!kv_self.ctx);
@@ -4103,7 +4069,7 @@ static int llama_decode_internal(
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
     //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32));   // TODO: this might be better for CUDA?
-    kv_self.n = std::min((int32_t) hparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
+    kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
 
     //printf("kv_self.n = %d\n", kv_self.n);
 
@@ -4128,6 +4094,8 @@ static int llama_decode_internal(
             ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
         }
     }
+
+    ggml_cuda_set_mul_mat_q(cparams.mul_mat_q);
 #endif
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -5416,7 +5384,7 @@ void llama_sample_classifier_free_guidance(
 
     GGML_ASSERT(ctx);
 
-    auto n_vocab = llama_n_vocab(ctx);
+    auto n_vocab = llama_n_vocab(llama_get_model(ctx));
 
     GGML_ASSERT(n_vocab == (int)candidates->size);
     GGML_ASSERT(!candidates->sorted);
@@ -5445,7 +5413,7 @@ void llama_sample_classifier_free_guidance(
 llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
     GGML_ASSERT(ctx);
 
-    auto N = float(llama_n_vocab(ctx));
+    auto N = float(llama_n_vocab(llama_get_model(ctx)));
     int64_t t_start_sample_us;
     t_start_sample_us = ggml_time_us();
 
@@ -5632,7 +5600,7 @@ struct llama_logit_info {
     };
     llama_logit_info(llama_context * ctx)
       : logits(llama_get_logits(ctx))
-      , n_vocab(llama_n_vocab(ctx))
+      , n_vocab(llama_n_vocab(llama_get_model(ctx)))
       , max_l(*std::max_element(logits, logits + n_vocab))
       , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l}))
       { }
@@ -5670,7 +5638,6 @@ struct llama_beam_search_data {
     size_t n_beams;
     int n_past;
     int n_predict;
-    int n_threads;
     std::vector<llama_beam> beams;
     std::vector<llama_beam> next_beams;
 
@@ -5680,12 +5647,11 @@ struct llama_beam_search_data {
     // Used to communicate to/from callback on beams state.
     std::vector<llama_beam_view> beam_views;
 
-    llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
+    llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict)
       : ctx(ctx)
       , n_beams(n_beams)
       , n_past(n_past)
       , n_predict(n_predict)
-      , n_threads(n_threads)
       , beam_views(n_beams) {
         beams.reserve(n_beams);
         next_beams.reserve(n_beams);
@@ -5722,7 +5688,7 @@ struct llama_beam_search_data {
         } else {
             // beam is not at end-of-sentence, so branch with next top_k tokens.
             if (!beam.tokens.empty()) {
-                llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0), n_threads);
+                llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0));
             }
             llama_logit_info logit_info(ctx);
             std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
@@ -5796,7 +5762,7 @@ struct llama_beam_search_data {
             callback(callback_data, get_beams_state(false));  // Sets common_prefix_length
             update_beams_from_beam_views();   // Update values (p,eob) that callback may have changed.
             if (common_prefix_length) {
-                llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0), n_threads);
+                llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0));
                 n_past += common_prefix_length;
             }
             // Zero-out next_beam probabilities to place them last in following min-heap.
@@ -5837,11 +5803,11 @@ struct llama_beam_search_data {
 
 void llama_beam_search(llama_context * ctx,
                        llama_beam_search_callback_fn_t callback, void * callback_data,
-                       size_t n_beams, int n_past, int n_predict, int n_threads) {
+                       size_t n_beams, int n_past, int n_predict) {
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
 
-    llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads);
+    llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict);
 
     beam_search_data.loop(callback, callback_data);
 
@@ -6061,11 +6027,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         nthread = std::thread::hardware_concurrency();
     }
 
-    std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname_inp, /*use_mmap*/ false));
+    llama_model_loader ml(fname_inp, /*use_mmap*/ false);
 
     llama_model model;
-    llm_load_arch(*ml, model);
-    llm_load_hparams(*ml, model, 0, 0, 0);
+    llm_load_arch(ml, model);
+    llm_load_hparams(ml, model);
 
     if (params->only_copy) {
         ftype = model.ftype;
@@ -6075,7 +6041,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     struct gguf_context * ctx_out = gguf_init_empty();
 
     // copy the KV pairs from the input file
-    gguf_set_kv     (ctx_out, ml->ctx_gguf);
+    gguf_set_kv     (ctx_out, ml.ctx_gguf);
     gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
     gguf_set_val_u32(ctx_out, "general.file_type", ftype);
 
@@ -6083,8 +6049,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     int n_attention_wv    = 0;
     int n_feed_forward_w2 = 0;
 
-    for (int i = 0; i < ml->n_tensors; ++i) {
-        struct ggml_tensor * meta = ml->get_tensor_meta(i);
+    for (int i = 0; i < ml.n_tensors; ++i) {
+        struct ggml_tensor * meta = ml.get_tensor_meta(i);
 
         const std::string name = ggml_get_name(meta);
 
@@ -6120,8 +6086,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     std::vector<no_init<float>> f32_conv_buf;
 
     // populate the original tensors so we get an initial meta data
-    for (int i = 0; i < ml->n_tensors; ++i) {
-        struct ggml_tensor * meta = ml->get_tensor_meta(i);
+    for (int i = 0; i < ml.n_tensors; ++i) {
+        struct ggml_tensor * meta = ml.get_tensor_meta(i);
         gguf_add_tensor(ctx_out, meta);
     }
 
@@ -6134,8 +6100,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     // placeholder for the meta data
     ::zeros(fout, meta_size);
 
-    for (int i = 0; i < ml->n_tensors; ++i) {
-        struct ggml_tensor * tensor = ml->get_tensor_meta(i);
+    for (int i = 0; i < ml.n_tensors; ++i) {
+        struct ggml_tensor * tensor = ml.get_tensor_meta(i);
 
         const std::string name = ggml_get_name(tensor);
 
@@ -6143,10 +6109,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             read_data.resize(ggml_nbytes(tensor));
         }
         tensor->data = read_data.data();
-        ml->load_data_for(tensor);
+        ml.load_data_for(tensor);
 
         LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
-               ++idx, ml->n_tensors,
+               ++idx, ml.n_tensors,
                ggml_get_name(tensor),
                llama_format_tensor_shape(tensor).c_str(),
                ggml_type_name(tensor->type));
@@ -6296,7 +6262,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     }
 }
 
-// TODO: after the GGUF PR, this likely won't work and needs to be updated
 static int llama_apply_lora_from_file_internal(
     const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
 ) {
@@ -6575,27 +6540,16 @@ static int llama_apply_lora_from_file_internal(
 //
 // interface implementation
 //
-
-struct llama_context_params llama_context_default_params() {
-    struct llama_context_params result = {
-        /*.seed                        =*/ LLAMA_DEFAULT_SEED,
-        /*.n_ctx                       =*/ 512,
-        /*.n_batch                     =*/ 512,
+struct llama_model_params llama_model_default_params() {
+    struct llama_model_params result = {
         /*.n_gpu_layers                =*/ 0,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
-        /*.rope_freq_base              =*/ 0.0f,
-        /*.rope_freq_scale             =*/ 0.0f,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
-        /*.low_vram                    =*/ false,
-        /*.mul_mat_q                   =*/ true,
-        /*.f16_kv                      =*/ true,
-        /*.logits_all                  =*/ false,
         /*.vocab_only                  =*/ false,
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
-        /*.embedding                   =*/ false,
     };
 
 #ifdef GGML_USE_METAL
@@ -6605,6 +6559,24 @@ struct llama_context_params llama_context_default_params() {
     return result;
 }
 
+struct llama_context_params llama_context_default_params() {
+    struct llama_context_params result = {
+        /*.seed                        =*/ LLAMA_DEFAULT_SEED,
+        /*.n_ctx                       =*/ 512,
+        /*.n_batch                     =*/ 512,
+        /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
+        /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
+        /*.rope_freq_base              =*/ 0.0f,
+        /*.rope_freq_scale             =*/ 0.0f,
+        /*.mul_mat_q                   =*/ true,
+        /*.f16_kv                      =*/ true,
+        /*.logits_all                  =*/ false,
+        /*.embedding                   =*/ false,
+    };
+
+    return result;
+}
+
 struct llama_model_quantize_params llama_model_quantize_default_params() {
     struct llama_model_quantize_params result = {
         /*.nthread                     =*/ 0,
@@ -6660,13 +6632,11 @@ int64_t llama_time_us(void) {
 
 struct llama_model * llama_load_model_from_file(
                              const char * path_model,
-            struct llama_context_params   params) {
+              struct llama_model_params   params) {
     ggml_time_init();
 
     llama_model * model = new llama_model;
 
-    ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
     unsigned cur_percentage = 0;
     if (params.progress_callback == NULL) {
         params.progress_callback_user_data = &cur_percentage;
@@ -6683,9 +6653,9 @@ struct llama_model * llama_load_model_from_file(
         };
     }
 
-    if (!llama_model_load(path_model, *model, params.n_ctx, params.n_batch, params.n_gpu_layers,
-                params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale,
-                params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only,
+    if (!llama_model_load(path_model, *model, params.n_gpu_layers,
+                params.main_gpu, params.tensor_split,
+                params.use_mmap, params.use_mlock, params.vocab_only,
                 params.progress_callback, params.progress_callback_user_data)) {
         LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
         delete model;
@@ -6709,18 +6679,33 @@ struct llama_context * llama_new_context_with_model(
 
     llama_context * ctx = new llama_context(*model);
 
+    const auto & hparams = model->hparams;
+    auto       & cparams = ctx->cparams;
+
+    cparams.n_batch         = params.n_batch;
+    cparams.n_ctx           = params.n_ctx == 0           ? hparams.n_ctx_train           : params.n_ctx;
+    cparams.rope_freq_base  = params.rope_freq_base == 0  ? hparams.rope_freq_base_train  : params.rope_freq_base;
+    cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale;
+    cparams.n_threads       = params.n_threads;
+    cparams.n_threads_batch = params.n_threads_batch;
+    cparams.mul_mat_q       = params.mul_mat_q;
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
 
+    LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
+    LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
+    LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
+
     ctx->rng = std::mt19937(params.seed);
     ctx->logits_all = params.logits_all;
 
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     // reserve memory for context buffers
-    if (!params.vocab_only) {
-        if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, params.n_gpu_layers)) {
+    if (!hparams.vocab_only) {
+        if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
@@ -6731,11 +6716,9 @@ struct llama_context * llama_new_context_with_model(
             LLAMA_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
         }
 
-        const auto & hparams = ctx->model.hparams;
-
         // resized during inference
         if (params.logits_all) {
-            ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
+            ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab);
         } else {
             ctx->logits.reserve(hparams.n_vocab);
         }
@@ -6753,12 +6736,13 @@ struct llama_context * llama_new_context_with_model(
             ctx->alloc = ggml_allocr_new_measure(tensor_alignment);
 
             // build worst-case graph
-            const uint32_t n_tokens = std::min((int) hparams.n_ctx, params.n_batch);
+            int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
+            int n_past = cparams.n_ctx - n_tokens;
             llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, hparams.n_ctx - n_tokens, 0));
+            ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));
 
 #ifdef GGML_USE_METAL
-            if (params.n_gpu_layers > 0) {
+            if (model->n_gpu_layers > 0) {
                 ctx->ctx_metal = ggml_metal_init(1);
                 if (!ctx->ctx_metal) {
                     LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__);
@@ -6773,7 +6757,7 @@ struct llama_context * llama_new_context_with_model(
             // measure memory requirements for the graph
             size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;
 
-            LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
+            LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
 
             // recreate allocator with exact memory requirements
             ggml_allocr_free(ctx->alloc);
@@ -6786,24 +6770,42 @@ struct llama_context * llama_new_context_with_model(
             }
 #endif
 #ifdef GGML_USE_CUBLAS
-            if (params.low_vram) {
-                LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
-                ggml_cuda_set_scratch_size(0); // disable scratch
-            } else {
-                ggml_cuda_set_scratch_size(alloc_size);
-                LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0);
+            ggml_cuda_set_scratch_size(alloc_size);
+            LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0);
+
+            // calculate total VRAM usage
+            auto add_tensor = [](const ggml_tensor * t, size_t & size) {
+                if (t->backend == GGML_BACKEND_GPU || t->backend == GGML_BACKEND_GPU_SPLIT) {
+                    size += ggml_nbytes(t);
+                }
+            };
+            size_t model_vram_size = 0;
+            for (const auto & kv : model->tensors_by_name) {
+                add_tensor(kv.second, model_vram_size);
             }
+
+            size_t kv_vram_size = 0;
+            add_tensor(ctx->kv_self.k, kv_vram_size);
+            add_tensor(ctx->kv_self.v, kv_vram_size);
+
+            size_t ctx_vram_size = alloc_size + kv_vram_size;
+            size_t total_vram_size = model_vram_size + ctx_vram_size;
+
+            LLAMA_LOG_INFO("%s: total VRAM used: %.2f MB (model: %.2f MB, context: %.2f MB)\n", __func__,
+                    total_vram_size / 1024.0 / 1024.0,
+                    model_vram_size / 1024.0 / 1024.0,
+                    ctx_vram_size / 1024.0 / 1024.0);
 #endif
         }
 
 #ifdef GGML_USE_METAL
-        if (params.n_gpu_layers > 0) {
+        if (model->n_gpu_layers > 0) {
             // this allocates all Metal resources and memory buffers
 
             void * data_ptr  = NULL;
             size_t data_size = 0;
 
-            if (params.use_mmap) {
+            if (ctx->model.mapping) {
                 data_ptr  = ctx->model.mapping->addr;
                 data_size = ctx->model.mapping->size;
             } else {
@@ -6822,11 +6824,8 @@ struct llama_context * llama_new_context_with_model(
                 return NULL;                                             \
             }
 
-            LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size));
-
-            LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.data, ctx->buf_compute.size, 0));
-            LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv",   ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0));
-
+            LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data",  data_ptr, data_size, max_size));
+            LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv",    ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0));
             LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.data, ctx->buf_alloc.size, 0));
 #undef LLAMA_METAL_CHECK_BUF
         }
@@ -6850,63 +6849,37 @@ struct llama_context * llama_new_context_with_model(
     return ctx;
 }
 
-static struct llama_context * llama_init_from_file(
-                             const char * path_model,
-            struct llama_context_params   params) {
-    struct llama_model * model = llama_load_model_from_file(path_model, params);
-    if (!model) {
-        return nullptr;
-    }
-
-    struct llama_context * ctx = llama_new_context_with_model(model, params);
-    ctx->model_owner = true;
-
-    return ctx;
-}
-
 void llama_free(struct llama_context * ctx) {
     delete ctx;
 }
 
-int llama_n_vocab(const struct llama_context * ctx) {
-    return llama_model_n_vocab(&ctx->model);
+const llama_model * llama_get_model(const struct llama_context * ctx) {
+    return &ctx->model;
 }
 
 int llama_n_ctx(const struct llama_context * ctx) {
-    return llama_model_n_ctx(&ctx->model);
-}
-
-int llama_n_ctx_train(const struct llama_context * ctx) {
-    return llama_model_n_ctx_train(&ctx->model);
+    return ctx->cparams.n_ctx;
 }
 
-int llama_n_embd(const struct llama_context * ctx) {
-    return llama_model_n_embd(&ctx->model);
+enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
+    return model->vocab.type;
 }
 
-enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) {
-    return ctx->model.vocab.type;
-}
-
-int llama_model_n_vocab(const struct llama_model * model) {
+int llama_n_vocab(const struct llama_model * model) {
     return model->vocab.id_to_token.size();
 }
 
-int llama_model_n_ctx(const struct llama_model * model) {
-    return model->hparams.n_ctx;
-}
-
-int llama_model_n_ctx_train(const struct llama_model * model) {
+int llama_n_ctx_train(const struct llama_model * model) {
     return model->hparams.n_ctx_train;
 }
 
-int llama_model_n_embd(const struct llama_model * model) {
+int llama_n_embd(const struct llama_model * model) {
     return model->hparams.n_embd;
 }
 
 int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
     return snprintf(buf, buf_size, "%s %s %s",
-            model->name.c_str(),
+            llama_model_arch_name(model->arch).c_str(),
             llama_model_type_name(model->type),
             llama_model_ftype_name(model->ftype).c_str());
 }
@@ -7131,9 +7104,11 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
     {
         const auto & kv_self = ctx->kv_self;
         const auto & hparams = ctx->model.hparams;
+        const auto & cparams = ctx->cparams;
+
         const int    n_layer = hparams.n_layer;
         const int    n_embd  = hparams.n_embd_gqa();
-        const int    n_ctx   = hparams.n_ctx;
+        const int    n_ctx   = cparams.n_ctx;
 
         const size_t kv_size = kv_self.buf.size;
         const int    kv_ntok = kv_self.head;
@@ -7239,9 +7214,11 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
     {
         const auto & kv_self = ctx->kv_self;
         const auto & hparams = ctx->model.hparams;
+        const auto & cparams = ctx->cparams;
+
         const int    n_layer = hparams.n_layer;
         const int    n_embd  = hparams.n_embd_gqa();
-        const int    n_ctx   = hparams.n_ctx;
+        const int    n_ctx   = cparams.n_ctx;
 
         size_t kv_size;
         int kv_ntok;
@@ -7378,11 +7355,10 @@ int llama_eval(
         struct llama_context * ctx,
                  llama_token * tokens,
                      int32_t   n_tokens,
-                         int   n_past,
-                         int   n_threads) {
+                         int   n_past) {
     llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
 
-    const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads);
+    const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -7394,13 +7370,12 @@ int llama_eval_embd(
             struct llama_context * ctx,
                            float * embd,
                          int32_t   n_tokens,
-                             int   n_past,
-                             int   n_threads) {
+                             int   n_past) {
     llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
 
     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
 
-    const int ret = llama_decode_internal(*ctx, batch, n_threads);
+    const int ret = llama_decode_internal(*ctx, batch);
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -7408,6 +7383,11 @@ int llama_eval_embd(
     return ret;
 }
 
+void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
+    ctx->cparams.n_threads       = n_threads;
+    ctx->cparams.n_threads_batch = n_threads_batch;
+}
+
 struct llama_batch llama_batch_get_one(
              llama_token * tokens,
                  int32_t   n_tokens,
@@ -7452,9 +7432,8 @@ void llama_batch_free(struct llama_batch batch) {
 
 int llama_decode(
         struct llama_context * ctx,
-          struct llama_batch   batch,
-                         int   n_threads) {
-    const int ret = llama_decode_internal(*ctx, batch, n_threads);
+          struct llama_batch   batch) {
+    const int ret = llama_decode_internal(*ctx, batch);
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -7499,16 +7478,6 @@ llama_token llama_token_nl(const struct llama_context * ctx) {
 }
 
 int llama_tokenize(
-        struct llama_context * ctx,
-                  const char * text,
-                         int   text_len,
-                 llama_token * tokens,
-                         int   n_max_tokens,
-                        bool   add_bos) {
-    return llama_tokenize_with_model(&ctx->model, text, text_len, tokens, n_max_tokens, add_bos);
-}
-
-int llama_tokenize_with_model(
     const struct llama_model * model,
                   const char * text,
                          int   text_len,
@@ -7529,13 +7498,9 @@ int llama_tokenize_with_model(
     return res.size();
 }
 
-int llama_token_to_piece(const struct llama_context * ctx, llama_token token, char * buf, int length) {
-    return llama_token_to_piece_with_model(&ctx->model, token, buf, length);
-}
-
 // does not write null-terminator to buf
-int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
-    if (0 <= token && token < llama_model_n_vocab(model)) {
+int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) {
+    if (0 <= token && token < llama_n_vocab(model)) {
         if (llama_is_normal_token(model->vocab, token)) {
             std::string result = model->vocab.id_to_token[token].text;
             if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) {
diff --git a/llama.h b/llama.h
index 046284d744895b1bbb1406cf07672e757b804ac4..96ff1f09c76dbd88838fb8398cbe0f13f269d463 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -149,32 +149,37 @@ extern "C" {
         llama_seq_id all_seq_id; // used if seq_id == NULL
     } llama_batch;
 
-    struct llama_context_params {
-        uint32_t seed;         // RNG seed, -1 for random
-        int32_t  n_ctx;        // text context
-        int32_t  n_batch;      // prompt processing batch size
-        int32_t  n_gpu_layers; // number of layers to store in VRAM
-        int32_t  main_gpu;     // the GPU that is used for scratch and small tensors
-
+    struct llama_model_params {
+        int32_t n_gpu_layers; // number of layers to store in VRAM
+        int32_t main_gpu;     // the GPU that is used for scratch and small tensors
         const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
 
-        // ref: https://github.com/ggerganov/llama.cpp/pull/2054
-        float    rope_freq_base;  // RoPE base frequency
-        float    rope_freq_scale; // RoPE frequency scaling factor
-
         // called with a progress value between 0 and 1, pass NULL to disable
         llama_progress_callback progress_callback;
         // context pointer passed to the progress callback
         void * progress_callback_user_data;
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool low_vram;   // if true, reduce VRAM usage at the cost of performance
-        bool mul_mat_q;  // if true, use experimental mul_mat_q kernels
-        bool f16_kv;     // use fp16 for KV cache
-        bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mmap;   // use mmap if possible
         bool use_mlock;  // force system to keep model in RAM
+    };
+
+    struct llama_context_params {
+        uint32_t seed;            // RNG seed, -1 for random
+        uint32_t n_ctx;           // text context
+        uint32_t n_batch;         // prompt processing batch size
+        uint32_t n_threads;       // number of threads to use for generation
+        uint32_t n_threads_batch; // number of threads to use for batch processing
+
+        // ref: https://github.com/ggerganov/llama.cpp/pull/2054
+        float rope_freq_base;  // RoPE base frequency
+        float rope_freq_scale; // RoPE frequency scaling factor
+
+        // Keep the booleans together to avoid misalignment during copy-by-value.
+        bool mul_mat_q;  // if true, use experimental mul_mat_q kernels
+        bool f16_kv;     // use fp16 for KV cache
+        bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool embedding;  // embedding mode only
     };
 
@@ -236,6 +241,7 @@ extern "C" {
     };
 
     // Helpers for getting default parameters
+    LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
 
@@ -249,7 +255,7 @@ extern "C" {
 
     LLAMA_API struct llama_model * llama_load_model_from_file(
                              const char * path_model,
-            struct llama_context_params   params);
+            struct llama_model_params     params);
 
     LLAMA_API void llama_free_model(struct llama_model * model);
 
@@ -266,17 +272,15 @@ extern "C" {
     LLAMA_API bool llama_mmap_supported (void);
     LLAMA_API bool llama_mlock_supported(void);
 
-    LLAMA_API int llama_n_vocab    (const struct llama_context * ctx);
+    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
+
     LLAMA_API int llama_n_ctx      (const struct llama_context * ctx);
-    LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
-    LLAMA_API int llama_n_embd     (const struct llama_context * ctx);
 
-    LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
+    LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
 
-    LLAMA_API int llama_model_n_vocab    (const struct llama_model * model);
-    LLAMA_API int llama_model_n_ctx      (const struct llama_model * model);
-    LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
-    LLAMA_API int llama_model_n_embd     (const struct llama_model * model);
+    LLAMA_API int llama_n_vocab    (const struct llama_model * model);
+    LLAMA_API int llama_n_ctx_train(const struct llama_model * model);
+    LLAMA_API int llama_n_embd     (const struct llama_model * model);
 
     // Get a string describing the model type
     LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
@@ -409,8 +413,7 @@ extern "C" {
             struct llama_context * ctx,
                      llama_token * tokens,
                          int32_t   n_tokens,
-                             int   n_past,
-                             int   n_threads),
+                             int   n_past),
             "use llama_decode() instead");
 
     // Same as llama_eval, but use float matrix input directly.
@@ -419,8 +422,7 @@ extern "C" {
             struct llama_context * ctx,
                            float * embd,
                          int32_t   n_tokens,
-                             int   n_past,
-                             int   n_threads),
+                             int   n_past),
             "use llama_decode() instead");
 
     // Return batch for single sequence of tokens starting at pos_0
@@ -452,8 +454,12 @@ extern "C" {
     // < 0 - error
     LLAMA_API int llama_decode(
             struct llama_context * ctx,
-              struct llama_batch   batch,
-                             int   n_threads);
+              struct llama_batch   batch);
+
+    // Set the number of threads used for decoding
+    // n_threads is the number of threads used for generation (single token)
+    // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
+    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
 
     // Token logits obtained from the last call to llama_eval()
     // The logits for the last token are stored in the last row
@@ -494,14 +500,6 @@ extern "C" {
     // Returns the number of tokens on success, no more than n_max_tokens
     // Returns a negative number on failure - the number of tokens that would have been returned
     LLAMA_API int llama_tokenize(
-            struct llama_context * ctx,
-                      const char * text,
-                             int   text_len,
-                     llama_token * tokens,
-                             int   n_max_tokens,
-                            bool   add_bos);
-
-    LLAMA_API int llama_tokenize_with_model(
         const struct llama_model * model,
                       const char * text,
                              int   text_len,
@@ -514,12 +512,6 @@ extern "C" {
     // Does not write null terminator to the buffer.
     // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
     LLAMA_API int llama_token_to_piece(
-            const struct llama_context * ctx,
-                           llama_token   token,
-                                  char * buf,
-                                  int    length);
-
-    LLAMA_API int llama_token_to_piece_with_model(
               const struct llama_model * model,
                            llama_token   token,
                                   char * buf,
@@ -700,15 +692,13 @@ extern "C" {
     /// @param n_beams Number of beams to use.
     /// @param n_past Number of tokens already evaluated.
     /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
-    /// @param n_threads Number of threads as passed to llama_eval().
     LLAMA_API void llama_beam_search(
                    struct llama_context * ctx,
         llama_beam_search_callback_fn_t   callback,
                                    void * callback_data,
                                  size_t   n_beams,
                                     int   n_past,
-                                    int   n_predict,
-                                    int   n_threads);
+                                    int   n_predict);
 
     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
index 836fb8ad271092e9e91dfface487546fe8492ecc..d51851e20822e4c4d615415dc4e159849fc01964 100644 (file)
@@ -62,18 +62,20 @@ int main(int argc, char **argv) {
 
     // load the vocab
     {
-        auto lparams = llama_context_default_params();
+        auto mparams = llama_model_default_params();
 
-        lparams.vocab_only = true;
+        mparams.vocab_only = true;
 
-        model = llama_load_model_from_file(fname.c_str(), lparams);
+        model = llama_load_model_from_file(fname.c_str(), mparams);
 
         if (model == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
             return 1;
         }
 
-        ctx = llama_new_context_with_model(model, lparams);
+        auto cparams = llama_context_default_params();
+
+        ctx = llama_new_context_with_model(model, cparams);
 
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@@ -82,7 +84,7 @@ int main(int argc, char **argv) {
         }
     }
 
-    if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_BPE) {
+    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
         fprintf(stderr, "%s : error: vocab type is not SPM\n", __func__);
         llama_free_model(model);
         llama_free(ctx);
index dfb2e81a9bc6ff19368c4aa4a3e5d61dc7056ea7..91c841f7bba8f690e2b75bf1ecd272526376bd6c 100644 (file)
@@ -64,18 +64,20 @@ int main(int argc, char **argv) {
 
     // load the vocab
     {
-        auto lparams = llama_context_default_params();
+        auto mparams = llama_model_default_params();
 
-        lparams.vocab_only = true;
+        mparams.vocab_only = true;
 
-        model = llama_load_model_from_file(fname.c_str(), lparams);
+        model = llama_load_model_from_file(fname.c_str(), mparams);
 
         if (model == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
             return 1;
         }
 
-        ctx = llama_new_context_with_model(model, lparams);
+        auto cparams = llama_context_default_params();
+
+        ctx = llama_new_context_with_model(model, cparams);
 
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@@ -84,7 +86,7 @@ int main(int argc, char **argv) {
         }
     }
 
-    if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_SPM) {
+    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_SPM) {
         fprintf(stderr, "%s : error: vocab type is not SPM\n", __func__);
         llama_free_model(model);
         llama_free(ctx);
index a95d462cfcd0b181c54950ff5697c3ba5c4456f0..3b2fc87ac48d8cbb5726c8ea02a5e6ae376a3def 100644 (file)
@@ -52,18 +52,20 @@ int main(int argc, char **argv) {
 
     // load the vocab
     {
-        auto lparams = llama_context_default_params();
+        auto mparams = llama_model_default_params();
 
-        lparams.vocab_only = true;
+        mparams.vocab_only = true;
 
-        model = llama_load_model_from_file(fname.c_str(), lparams);
+        model = llama_load_model_from_file(fname.c_str(), mparams);
 
         if (model == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
             return 1;
         }
 
-        ctx = llama_new_context_with_model(model, lparams);
+        auto cparams = llama_context_default_params();
+
+        ctx = llama_new_context_with_model(model, cparams);
 
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@@ -72,7 +74,7 @@ int main(int argc, char **argv) {
         }
     }
 
-    GGML_ASSERT(llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM);
+    GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
 
 #ifdef _WIN32
     // We need this for unicode console support
@@ -80,7 +82,7 @@ int main(int argc, char **argv) {
     atexit([]() { console::cleanup(); });
 #endif
 
-    const int n_vocab = llama_n_vocab(ctx);
+    const int n_vocab = llama_n_vocab(model);
 
     for (int i = 0; i < n_vocab; ++i) {
         std::string str = llama_detokenize_spm(ctx, std::vector<int>(1, i));