]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Load all MoE experts during warmup (#11571)
authorfairydreaming <redacted>
Fri, 14 Mar 2025 12:47:05 +0000 (13:47 +0100)
committerGitHub <redacted>
Fri, 14 Mar 2025 12:47:05 +0000 (13:47 +0100)
* llama : introduce llama_set_warmup() API call that controls warmup mode; use all MoE experts during warmup

* common : use new API to enable warmup mode during model warmup

---------

Co-authored-by: Stanisław Szymczyk <redacted>
common/common.cpp
include/llama.h
src/llama-context.cpp
src/llama-context.h
src/llama-cparams.h
src/llama-graph.cpp

index 8487e3834bccb474c2961da607ed35101a6c245e..18ffb4e738aee3ddc415cb106b9b2cfcb34622a1 100644 (file)
@@ -1033,6 +1033,8 @@ struct common_init_result common_init_from_params(common_params & params) {
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
+        llama_set_warmup(lctx, true);
+
         std::vector<llama_token> tmp;
         llama_token bos = llama_vocab_bos(vocab);
         llama_token eos = llama_vocab_eos(vocab);
@@ -1063,6 +1065,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         llama_kv_self_clear(lctx);
         llama_synchronize(lctx);
         llama_perf_context_reset(lctx);
+        llama_set_warmup(lctx, false);
     }
 
     iparams.model.reset(model);
index e5286f06162ab5a90118b11a6a9ef783326bf632..6a44be404d9142b04219ac72de9955d26e58ef77 100644 (file)
@@ -945,6 +945,10 @@ extern "C" {
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
 
+    // Set whether the model is in warmup mode or not
+    // If true, all model tensors are activated during llama_decode() to load and cache their weights.
+    LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
+
     // Set abort callback
     LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
 
index 4df6b18ec1de33bb1eb0568b5f5c5eeea4d76ae8..c2fcce42a7d58bae98f898f8e5dd4c5e3271fb10 100644 (file)
@@ -39,6 +39,7 @@ llama_context::llama_context(
     cparams.flash_attn       = params.flash_attn;
     cparams.no_perf          = params.no_perf;
     cparams.pooling_type     = params.pooling_type;
+    cparams.warmup           = false;
 
     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
     cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
@@ -948,6 +949,12 @@ void llama_context::set_causal_attn(bool value) {
     cparams.causal_attn = value;
 }
 
+void llama_context::set_warmup(bool value) {
+    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
+
+    cparams.warmup = value;
+}
+
 void llama_context::set_adapter_lora(
             llama_adapter_lora * adapter,
             float scale) {
@@ -1594,7 +1601,7 @@ void llama_context::output_reorder() {
 //
 
 int32_t llama_context::graph_max_nodes() const {
-    return std::max<int32_t>(8192, 5*model.n_tensors());
+    return std::max<int32_t>(65536, 5*model.n_tensors());
 }
 
 ggml_cgraph * llama_context::graph_init() {
@@ -2372,6 +2379,10 @@ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
     ctx->set_causal_attn(causal_attn);
 }
 
+void llama_set_warmup(llama_context * ctx, bool warmup) {
+    ctx->set_warmup(warmup);
+}
+
 void llama_synchronize(llama_context * ctx) {
     ctx->synchronize();
 }
index 88df8950e4cb0f6039ea50c87aa7f1b45d944f59..04facb544cb1a54b43a7e58eedd1ca4368df0ce4 100644 (file)
@@ -64,6 +64,7 @@ struct llama_context {
 
     void set_embeddings (bool value);
     void set_causal_attn(bool value);
+    void set_warmup(bool value);
 
     void set_adapter_lora(
             llama_adapter_lora * adapter,
index 252012f3d9405ac04022303528c0d0c082a46879..30e550f023a9e323fae37fd364d9fcb8a3745e41 100644 (file)
@@ -29,6 +29,7 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     bool no_perf;
+    bool warmup;
 
     enum llama_pooling_type pooling_type;
 
index e4af507780aa1a4447e020378779d86221f91e6a..4e90873397ca4c230f3c79f82095a415c25e455f 100644 (file)
@@ -577,7 +577,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_embd_head_v    (hparams.n_embd_head_v),
     n_embd_v_gqa     (hparams.n_embd_v_gqa()),
     n_expert         (hparams.n_expert),
-    n_expert_used    (hparams.n_expert_used),
+    n_expert_used    (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
     freq_base        (cparams.rope_freq_base),
     freq_scale       (cparams.rope_freq_scale),
     ext_factor       (cparams.yarn_ext_factor),