git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model: EmbeddingGemma: add support for SentenceTransformers Dense modules (#16367)
author     Saba Fallah <redacted>
Thu, 9 Oct 2025 06:39:18 +0000 (08:39 +0200)
committer  GitHub <redacted>
Thu, 9 Oct 2025 06:39:18 +0000 (09:39 +0300)
* model: EmbeddingGemma sentence-transformers dense linear projections support

* model: add support for EmbeddingGemma SentenceTransformers dense linear projections

Adding support for the Dense modules used in EmbeddingGemma models.
EmbeddingGemma is a SentenceTransformers model with additional modules beyond the base Transformer backbone.

See: https://developers.googleblog.com/en/gemma-explained-embeddinggemma-architecture-and-recipe/

* model: add support for EmbeddingGemma SentenceTransformers dense linear projections

- converting the model with dense layers is optional
- introduced dense config params

* Update convert_hf_to_gguf.py

Co-authored-by: Daniel Bevenius <redacted>
* fixed formatting issues

* Update src/llama-graph.cpp

Co-authored-by: Georgi Gerganov <redacted>
* - removed pooling_type_opt; always allow overriding pooling_type
- added asserts checking dense feature dims

* fix python lint

* fix ubuntu gcc build warning

* - fixed thread-safety test
- moved asserts to load_hparams

* - tidying up code
- simplifying graph-context expecting both dense weights

* minor : add TODO

---------

Co-authored-by: Daniel Bevenius <redacted>
Co-authored-by: Georgi Gerganov <redacted>
12 files changed:
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-context.cpp
src/llama-graph.cpp
src/llama-graph.h
src/llama-hparams.h
src/llama-model.cpp
src/llama-model.h
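
Note: for orientation, the SentenceTransformers checkpoint wraps the Gemma3 text backbone in a small module pipeline (transformer, pooling, two Dense linear projections; see the architecture link in the commit message). A minimal sketch of that head, assuming the 768 -> 3072 -> 768 sizes of google/embeddinggemma-300m and bias-free Dense layers (this patch only exports the Dense *.weight tensors; any final normalization module is out of scope):

    import torch

    n_embd = 768                                          # hidden size of the text backbone (assumed)
    dense_2 = torch.nn.Linear(n_embd, 3072, bias=False)   # module 2_Dense (assumed sizes)
    dense_3 = torch.nn.Linear(3072, n_embd, bias=False)   # module 3_Dense (assumed sizes)

    def sentence_embedding(token_states: torch.Tensor) -> torch.Tensor:
        # token_states: [n_tokens, n_embd] output of the transformer backbone
        pooled = token_states.mean(dim=0)   # mean pooling (module 1_Pooling)
        return dense_3(dense_2(pooled))     # the two projections this patch adds to the graph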

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index a59ebfc0da7766ac5a6ce3ceb6a80d5b762fc6b8..43d345bcb480c9266e4b101b74153a4e694d5de5 100755 (executable)
@@ -93,13 +93,15 @@ class ModelBase:
     # Mistral format specifics
     is_mistral_format: bool = False
     disable_mistral_community_chat_template: bool = False
+    sentence_transformers_dense_modules: bool = False
 
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
-                 disable_mistral_community_chat_template: bool = False):
+                 disable_mistral_community_chat_template: bool = False,
+                 sentence_transformers_dense_modules: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -114,6 +116,7 @@ class ModelBase:
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
+        self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
         if remote_hf_model_id is not None:
             self.is_safetensors = True
 
@@ -5269,6 +5272,53 @@ class Gemma3Model(TextModel):
 @ModelBase.register("Gemma3TextModel")
 class EmbeddingGemma(Gemma3Model):
     model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.module_paths = []          # per-instance state rather than shared mutable class attributes
+        self.dense_features_dims = {}
+        if self.sentence_transformers_dense_modules:
+            # read modules.json to determine if model has Dense layers
+            modules_file = self.dir_model / "modules.json"
+            if modules_file.is_file():
+                with open(modules_file, encoding="utf-8") as modules_json_file:
+                    mods = json.load(modules_json_file)
+                for mod in mods:
+                    if mod["type"] == "sentence_transformers.models.Dense":
+                        mod_path = mod["path"]
+                        # check if model.safetensors file for Dense layer exists
+                        model_tensors_file = self.dir_model / mod_path / "model.safetensors"
+                        if model_tensors_file.is_file():
+                            self.module_paths.append(mod_path)
+                            # read config.json of the Dense layer to get in/out features
+                            mod_conf_file = self.dir_model / mod_path / "config.json"
+                            if mod_conf_file.is_file():
+                                with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
+                                    mod_conf = json.load(mod_conf_json_file)
+                                    # hparams dense_2_feat_out and dense_3_feat_in are required when loading model's dense weights
+                                    prefix = self._get_dense_prefix(mod_path)
+                                    if mod_conf["in_features"] is not None and mod_conf["out_features"] is not None:
+                                        self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        from safetensors.torch import load_file
+        module_paths = list(self.module_paths)
+        for i, module_path in enumerate(module_paths):
+            tensors_file = self.dir_model / module_path / "model.safetensors"
+            local_tensors = load_file(tensors_file)
+            tensor_name = self._get_dense_prefix(module_path)
+            for name, local_tensor in local_tensors.items():
+                if not name.endswith(".weight"):
+                    continue
+                orig_name = name.replace("linear", tensor_name)
+                name = self.map_tensor_name(orig_name)
+                yield name, local_tensor.clone()
+
+    @staticmethod
+    def _get_dense_prefix(module_path) -> str:
+        """Get the tensor name prefix for the Dense layer from module path."""
+        tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
+        return tensor_name
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
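
Note: modules.json is the standard SentenceTransformers module manifest scanned by the constructor above. An illustrative layout (paths and ordering are an assumed example, not copied from the model) and the prefix mapping it produces:

    import json

    modules_json = """
    [
      {"idx": 0, "name": "0", "path": "",          "type": "sentence_transformers.models.Transformer"},
      {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
      {"idx": 2, "name": "2", "path": "2_Dense",   "type": "sentence_transformers.models.Dense"},
      {"idx": 3, "name": "3", "path": "3_Dense",   "type": "sentence_transformers.models.Dense"}
    ]
    """

    dense_paths = [m["path"] for m in json.loads(modules_json)
                   if m["type"] == "sentence_transformers.models.Dense"]
    prefixes = ["dense_2" if p == "2_Dense" else "dense_3" for p in dense_paths]
    print(prefixes)  # ['dense_2', 'dense_3']; each <path>/config.json supplies the in/out features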
@@ -5285,6 +5335,10 @@ class EmbeddingGemma(Gemma3Model):
             logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
                         f"instead of {self.hparams['sliding_window']}")
             self.gguf_writer.add_sliding_window(orig_sliding_window)
+        if self.sentence_transformers_dense_modules:
+            for dense, dims in self.dense_features_dims.items():
+                logger.info(f"Setting dense layer {dense} in/out features to {dims}")
+                self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
 
         self._try_set_pooling_type()
 
@@ -9335,6 +9389,13 @@ def parse_args() -> argparse.Namespace:
         )
     )
 
+    parser.add_argument(
+        "--sentence-transformers-dense-modules", action="store_true",
+        help=("Whether to include sentence-transformers dense modules."
+              "It can be used for sentence-transformers models, like google/embeddinggemma-300m"
+              "Default these modules are not included.")
+    )
+
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
         parser.error("the following arguments are required: model")
@@ -9397,9 +9458,13 @@ def main() -> None:
     if args.remote:
         hf_repo_id = args.model
         from huggingface_hub import snapshot_download
+        allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
+        if args.sentence_transformers_dense_modules:
+            # include sentence-transformers dense modules safetensors files
+            allowed_patterns.append("*.safetensors")
         local_dir = snapshot_download(
             repo_id=hf_repo_id,
-            allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
+            allow_patterns=allowed_patterns)
         dir_model = Path(local_dir)
         logger.info(f"Downloaded config and tokenizer to {local_dir}")
     else:
@@ -9467,7 +9532,8 @@ def main() -> None:
                                      split_max_tensors=args.split_max_tensors,
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                      small_first_shard=args.no_tensor_first_split,
-                                     remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
+                                     remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
+                                     sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
                                      )
 
         if args.vocab_only:
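
Note: a quick post-conversion sanity check using gguf-py's reader; the file name and the "gemma-embedding" key prefix are assumptions for illustration:

    from gguf import GGUFReader

    reader = GGUFReader("embeddinggemma.gguf")  # hypothetical converter output

    # dense tensors should be present only if --sentence-transformers-dense-modules was used
    print([t.name for t in reader.tensors if t.name.startswith("dense_")])
    # e.g. ['dense_2.weight', 'dense_3.weight']

    print([k for k in reader.fields if "_feat_in" in k or "_feat_out" in k])
    # e.g. ['gemma-embedding.dense_2_feat_in', ..., 'gemma-embedding.dense_3_feat_out']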
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 9c99b90faace83d6dc65522563ee112ed96fb99d..f5e5fba8008bd0e64baf231b3d773098a0c9010f 100644 (file)
@@ -128,6 +128,8 @@ class Keys:
         ALTUP_ACTIVE_IDX                  = "{arch}.altup.active_idx"
         ALTUP_NUM_INPUTS                  = "{arch}.altup.num_inputs"
         EMBD_LENGTH_PER_LAYER_INP         = "{arch}.embedding_length_per_layer_input"
+        DENSE_FEAT_IN_SIZE                = "{arch}.{dense}_feat_in"
+        DENSE_FEAT_OUT_SIZE               = "{arch}.{dense}_feat_out"
 
     class Attention:
         HEAD_COUNT                   = "{arch}.attention.head_count"
@@ -433,6 +435,8 @@ class MODEL_TENSOR(IntEnum):
     TOKEN_TYPES          = auto()
     POS_EMBD             = auto()
     OUTPUT               = auto()
+    DENSE_2_OUT          = auto() # embeddinggemma 2_Dense
+    DENSE_3_OUT          = auto() # embeddinggemma 3_Dense
     OUTPUT_NORM          = auto()
     ROPE_FREQS           = auto()
     ROPE_FACTORS_LONG    = auto()
@@ -777,6 +781,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.POS_EMBD:                  "position_embd",
     MODEL_TENSOR.OUTPUT_NORM:               "output_norm",
     MODEL_TENSOR.OUTPUT:                    "output",
+    MODEL_TENSOR.DENSE_2_OUT:               "dense_2", # embeddinggemma 2_Dense
+    MODEL_TENSOR.DENSE_3_OUT:               "dense_3", # embeddinggemma 3_Dense
     MODEL_TENSOR.ROPE_FREQS:                "rope_freqs",
     MODEL_TENSOR.ROPE_FACTORS_LONG:         "rope_factors_long",
     MODEL_TENSOR.ROPE_FACTORS_SHORT:        "rope_factors_short",
@@ -1759,6 +1765,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
     MODEL_ARCH.GEMMA_EMBEDDING: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.DENSE_2_OUT,
+        MODEL_TENSOR.DENSE_3_OUT,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_Q_NORM,
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index dfe4bfd490519a70eee338deaa8c067d5ea2e94d..306679e21834b88ab09dc070da0a51a8bc712177 100644 (file)
@@ -730,6 +730,10 @@ class GGUFWriter:
     def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
         self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
 
+    def add_dense_features_dims(self, dense: str, in_f: int, out_f: int) -> None:
+        self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
+        self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f)
+
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
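
Note: the new helper expands the key templates from constants.py into two uint32 KV pairs per dense layer. A hedged usage sketch; the "gemma-embedding" arch string and the 768/3072 sizes are illustrative, and the GGUFWriter header/KV flushing calls are omitted:

    import gguf

    writer = gguf.GGUFWriter("dense-demo.gguf", "gemma-embedding")
    writer.add_dense_features_dims("dense_2", 768, 3072)  # -> gemma-embedding.dense_2_feat_in / _feat_out
    writer.add_dense_features_dims("dense_3", 3072, 768)  # -> gemma-embedding.dense_3_feat_in / _feat_out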
 
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 3e9a2dd8f8cc912e2dffc4ae00bba25d3fd15d24..c05aa6cc488de6489a6c33d05de8fb4f79ab645b 100644 (file)
@@ -76,7 +76,12 @@ class TensorNameMap:
             "lm_head",                   # llama4
             "model.transformer.ff_out",  # llada
         ),
-
+        MODEL_TENSOR.DENSE_2_OUT: (
+            "dense_2_out",  # embeddinggemma
+        ),
+        MODEL_TENSOR.DENSE_3_OUT: (
+            "dense_3_out",  # embeddinggemma
+        ),
         # Output norm
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm",               # gptneox
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 45f0d0e2cbbd4a763b469d52dfc65e352fc255ce..869e4dccf0dc9922bfb41f9476108b2d3cbbb6d5 100644 (file)
@@ -219,6 +219,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
 
     { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
+    // sentence-transformers dense modules feature dims
+    { LLM_KV_DENSE_2_FEAT_IN,        "%s.dense_2_feat_in"  },
+    { LLM_KV_DENSE_2_FEAT_OUT,       "%s.dense_2_feat_out"  },
+    { LLM_KV_DENSE_3_FEAT_IN,        "%s.dense_3_feat_in"   },
+    { LLM_KV_DENSE_3_FEAT_OUT,       "%s.dense_3_feat_out"  },
 
     { LLM_KV_TOKENIZER_MODEL,                "tokenizer.ggml.model"                    },
     { LLM_KV_TOKENIZER_PRE,                  "tokenizer.ggml.pre"                      },
@@ -1071,6 +1076,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
             { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
             { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_DENSE_2_OUT,     "dense_2" },
+            { LLM_TENSOR_DENSE_3_OUT,     "dense_3" },
             { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
             { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
             { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
@@ -2281,6 +2288,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_OUTPUT,                     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS,                        {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS_OUT,                    {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DENSE_2_OUT,                {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
+    {LLM_TENSOR_DENSE_3_OUT,                {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
     {LLM_TENSOR_OUTPUT_NORM,                {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
     {LLM_TENSOR_DEC_OUTPUT_NORM,            {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
     {LLM_TENSOR_ENC_OUTPUT_NORM,            {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 507fe5f3793e0b62666997b2f35e478767f26ad5..c3ae71655b17b417acce9f052925296802ba21a2 100644 (file)
@@ -271,6 +271,12 @@ enum llm_kv {
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,
     LLM_KV_TOKENIZER_MIDDLE_ID,
+
+    // sentence-transformers dense layers in and out features
+    LLM_KV_DENSE_2_FEAT_IN,
+    LLM_KV_DENSE_2_FEAT_OUT,
+    LLM_KV_DENSE_3_FEAT_IN,
+    LLM_KV_DENSE_3_FEAT_OUT,
 };
 
 enum llm_tensor {
@@ -278,6 +284,8 @@ enum llm_tensor {
     LLM_TENSOR_TOKEN_EMBD_NORM,
     LLM_TENSOR_TOKEN_TYPES,
     LLM_TENSOR_POS_EMBD,
+    LLM_TENSOR_DENSE_2_OUT,
+    LLM_TENSOR_DENSE_3_OUT,
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
     LLM_TENSOR_ROPE_FREQS,
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index d8a8b5e647a8508303f2ff548db816141e0e781a..e7526e7d0a55791630066cad4ce0df075a9c8e64 100644 (file)
@@ -2346,6 +2346,12 @@ llama_context * llama_init_from_model(
         return nullptr;
     }
 
+    if (params.pooling_type != model->hparams.pooling_type) {
+        //user-specified pooling-type is different from the model default
+        LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
+                       model->hparams.pooling_type, params.pooling_type);
+    }
+
     try {
         auto * ctx = new llama_context(*model, params);
         return ctx;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 90cd885a60a4f6801d61c7abf4ad610ead4f81a7..a24853c63ada4006f10b059a71e5bc22750ed2c0 100644 (file)
@@ -1853,6 +1853,23 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }
 
+void llm_graph_context::build_dense_out(
+    ggml_tensor * dense_2,
+    ggml_tensor * dense_3) const {
+    if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
+        return;
+    }
+    ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
+    GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
+
+    cur = ggml_mul_mat(ctx0, dense_2, cur);
+    cur = ggml_mul_mat(ctx0, dense_3, cur);
+    cb(cur, "result_embd_pooled", -1);
+    res->t_embd_pooled = cur;
+    ggml_build_forward_expand(gf, cur);
+}
+
+
 void llm_graph_context::build_pooling(
         ggml_tensor * cls,
         ggml_tensor * cls_b,
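
Note: with the tensor shapes declared in llama-model.cpp below (dense_2 stored as {n_embd, dense_2_feat_out}, dense_3 as {dense_3_feat_in, n_embd}), the two ggml_mul_mat calls amount to applying the transposed Linear weights to the pooled embeddings. A rough numpy restatement, with assumed sizes:

    import numpy as np

    n_seqs, n_embd, n_mid = 4, 768, 3072                        # assumed dimensions
    pooled = np.random.rand(n_seqs, n_embd).astype(np.float32)  # res->t_embd_pooled, one row per sequence
    w2 = np.random.rand(n_mid, n_embd).astype(np.float32)       # dense_2 in torch (out, in) layout
    w3 = np.random.rand(n_embd, n_mid).astype(np.float32)       # dense_3

    out = pooled @ w2.T @ w3.T   # the two projections applied after pooling
    assert out.shape == (n_seqs, n_embd)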
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 34b984afeb04379e1e7b5aa4c792931fd3b03b9e..dc84b7942893a10d4084950c70d529ebcd051c67 100644 (file)
@@ -814,6 +814,14 @@ struct llm_graph_context {
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
             ggml_tensor * cls_out_b) const;
+
+    //
+    // dense (out)
+    //
+
+    void build_dense_out(
+            ggml_tensor * dense_2,
+            ggml_tensor * dense_3) const;
 };
 
 // TODO: better name
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index f29b23eeffe56bec6883abe2db2fc3740edb29ea..4e7f73ec234c33f1164b2e191273436f17bda2fd 100644 (file)
@@ -169,6 +169,12 @@ struct llama_hparams {
     uint32_t laurel_rank  = 64;
     uint32_t n_embd_altup = 256;
 
+    // needed for sentence-transformers dense layers
+    uint32_t dense_2_feat_in  = 0;  // in_features of the 2_Dense
+    uint32_t dense_2_feat_out = 0;  // out_features of the 2_Dense
+    uint32_t dense_3_feat_in  = 0;  // in_features of the 3_Dense
+    uint32_t dense_3_feat_out = 0;  // out_features of the 3_Dense
+
     // xIELU
     std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_n;
     std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_p;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 03c2f49d7826766aee44c431979b0834da4141a7..a5fe5b749c355875fbb453e27fbed7707afe6185 100644 (file)
@@ -1218,12 +1218,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.set_swa_pattern(6);
 
                 hparams.causal_attn = false; // embeddings do not use causal attention
-                hparams.rope_freq_base_train_swa  = 10000.0f;
+                hparams.rope_freq_base_train_swa = 10000.0f;
                 hparams.rope_freq_scale_train_swa = 1.0f;
 
-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_POOLING_TYPE,                hparams.pooling_type);
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+
+                // applied only if the model was converted with --sentence-transformers-dense-modules
+                ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false);
+                ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false);
+                ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false);
+                ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false);
+
+                GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd");
+                GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd");
 
                 switch (hparams.n_layer) {
                     case 24: type = LLM_TYPE_0_3B; break;
@@ -3686,6 +3695,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,   "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
 
+                    // Dense linear weights
+                    dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED);
+                    dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED);
+
+
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
 
@@ -19893,6 +19907,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
     // add on pooling layer
     llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
 
+    // if the gguf model was converted with --sentence-transformers-dense-modules
+    // there will be two additional dense projection layers
+    // dense linear projections are applied after pooling
+    // TODO: move reranking logic here and generalize
+    llm->build_dense_out(dense_2_out_layers, dense_3_out_layers);
+
     return llm->res->get_gf();
 }
 
diff --git a/src/llama-model.h b/src/llama-model.h
index 20b59d952bf909e433c979e8d6df508af8a3aca4..7f48662f2807ac46e6b24335e850e3585b1575ed 100644 (file)
@@ -438,6 +438,12 @@ struct llama_model {
 
     std::vector<llama_layer> layers;
 
+    // Dense linear projections for SentenceTransformers models like embeddinggemma
+    // For Sentence Transformers models structure see
+    // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models
+    struct ggml_tensor * dense_2_out_layers = nullptr;
+    struct ggml_tensor * dense_3_out_layers = nullptr;
+
     llama_model_params params;
 
     // gguf metadata