From: Vitali Lovich Date: Tue, 25 Feb 2025 09:29:33 +0000 (-0800) Subject: llama : expose llama_model_n_head_kv in the API (#11997) X-Git-Tag: upstream/0.0.4853~82 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=3e9a2860e996657fc10db8393cf65adc40703082;p=pkg%2Fggml%2Fsources%2Fllama.cpp llama : expose llama_model_n_head_kv in the API (#11997) It's useful to be able to have this from the library layer as it's a key parameter of the model (e.g. to figure out how much KV cache memory is needed). --- diff --git a/include/llama.h b/include/llama.h index b0726cbe..47919602 100644 --- a/include/llama.h +++ b/include/llama.h @@ -477,6 +477,7 @@ extern "C" { LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f64c3afa..36a0a009 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3838,6 +3838,10 @@ int32_t llama_model_n_head(const struct llama_model * model) { return model->hparams.n_head(); } +int32_t llama_model_n_head_kv(const struct llama_model * model) { + return model->hparams.n_head_kv(); +} + // deprecated int32_t llama_n_ctx_train(const struct llama_model * model) { return llama_model_n_ctx_train(model);