]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : expose llama_model_n_head_kv in the API (#11997)
authorVitali Lovich <redacted>
Tue, 25 Feb 2025 09:29:33 +0000 (01:29 -0800)
committerGitHub <redacted>
Tue, 25 Feb 2025 09:29:33 +0000 (11:29 +0200)
It's useful to be able to have this from the library layer as it's a key
parameter of the model (e.g. to figure out how much KV cache memory is
needed).

include/llama.h
src/llama-model.cpp

index b0726cbe63ea6b1480ab4008a6c97d85c18c4a0e..479196026b93bf5f3a7600f5a4c100326ebad152 100644 (file)
@@ -477,6 +477,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);
 
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
index f64c3afa029830da7ef77980460af2cb10f990e6..36a0a009c45672d1c4b17f2489b48c3759589810 100644 (file)
@@ -3838,6 +3838,10 @@ int32_t llama_model_n_head(const struct llama_model * model) {
     return model->hparams.n_head();
 }
 
+int32_t llama_model_n_head_kv(const struct llama_model * model) {
+    return model->hparams.n_head_kv();
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const struct llama_model * model) {
     return llama_model_n_ctx_train(model);