* llama : introduce llama_set_warmup() API call that controls warmup mode; use all MoE experts during warmup
* common : use new API to enable warmup mode during model warmup
---------
Co-authored-by: Stanisław Szymczyk <redacted>
if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
+ llama_set_warmup(lctx, true);
+
std::vector<llama_token> tmp;
llama_token bos = llama_vocab_bos(vocab);
llama_token eos = llama_vocab_eos(vocab);
llama_kv_self_clear(lctx);
llama_synchronize(lctx);
llama_perf_context_reset(lctx);
+ llama_set_warmup(lctx, false);
}
iparams.model.reset(model);
// If set to true, the model will only attend to the past tokens
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
+ // Set whether the model is in warmup mode or not
+ // If true, all model tensors are activated during llama_decode() to load and cache their weights.
+ LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
+
// Set abort callback
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
+ cparams.warmup = false;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
cparams.causal_attn = value;
}
+void llama_context::set_warmup(bool value) {
+ LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
+
+ cparams.warmup = value;
+}
+
void llama_context::set_adapter_lora(
llama_adapter_lora * adapter,
float scale) {
//
int32_t llama_context::graph_max_nodes() const {
- return std::max<int32_t>(8192, 5*model.n_tensors());
+ return std::max<int32_t>(65536, 5*model.n_tensors());
}
ggml_cgraph * llama_context::graph_init() {
ctx->set_causal_attn(causal_attn);
}
+void llama_set_warmup(llama_context * ctx, bool warmup) {
+ ctx->set_warmup(warmup);
+}
+
void llama_synchronize(llama_context * ctx) {
ctx->synchronize();
}
void set_embeddings (bool value);
void set_causal_attn(bool value);
+ void set_warmup(bool value);
void set_adapter_lora(
llama_adapter_lora * adapter,
bool offload_kqv;
bool flash_attn;
bool no_perf;
+ bool warmup;
enum llama_pooling_type pooling_type;
n_embd_head_v (hparams.n_embd_head_v),
n_embd_v_gqa (hparams.n_embd_v_gqa()),
n_expert (hparams.n_expert),
- n_expert_used (hparams.n_expert_used),
+ n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
freq_base (cparams.rope_freq_base),
freq_scale (cparams.rope_freq_scale),
ext_factor (cparams.yarn_ext_factor),