model : support for LlamaBidirectionalModel architecture (#18220)
author Saba Fallah <redacted>
Wed, 24 Dec 2025 13:02:36 +0000 (14:02 +0100)
committer GitHub <redacted>
Wed, 24 Dec 2025 13:02:36 +0000 (14:02 +0100)
* model: llama-embed-nemotron

* minor: python lint

* changed arch-name

* templated llm_build_llama to be used for both llama and llama-embed arch
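Note: the last bullet is the core of the change. llm_build_llama gains a bool template parameter so a single graph builder serves both the causal llama arch and the new bidirectional llama-embed arch, with the attention input selected at compile time. As a rough standalone illustration of that pattern only (toy names, not llama.cpp code):

    #include <type_traits>

    struct attn_input_kv       { };   // stand-in: causal attention backed by a KV cache
    struct attn_input_no_cache { };   // stand-in: full bidirectional attention, no cache

    template <bool embed>
    struct graph_builder {
        // the attention-input type is selected at compile time, as in the real patch
        using inp_attn_t = std::conditional_t<embed, attn_input_no_cache, attn_input_kv>;

        attn_input_kv       kv_inp;
        attn_input_no_cache nc_inp;

        attn_input_kv       * build_attn_inp_kv()       { return &kv_inp; }
        attn_input_no_cache * build_attn_inp_no_cache() { return &nc_inp; }

        graph_builder() {
            inp_attn_t * inp_attn = nullptr;
            if constexpr (embed) {
                inp_attn = build_attn_inp_no_cache();   // llama-embed: no KV cache
            } else {
                inp_attn = build_attn_inp_kv();         // llama: causal decoding
            }
            (void) inp_attn;
            // ... the shared transformer layers would be built here ...
            if constexpr (!embed) {
                // only the causal variant adds the lm_head and produces logits
            }
        }
    };

    int main() {
        graph_builder<false> causal;   // LLM_ARCH_LLAMA path
        graph_builder<true>  embed;    // LLM_ARCH_LLAMA_EMBED path
        (void) causal;
        (void) embed;
        return 0;
    }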

convert_hf_to_gguf.py
gguf-py/gguf/constants.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/models/llama.cpp
src/models/models.h

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 16c5acf346d5d862ac54e3b5c1aa667dfcb4cf20..c638e3398665ea914aa4a2edbac39c7dbb4ef210 100755 (executable)
@@ -8695,6 +8695,11 @@ class NemotronHModel(GraniteHybridModel):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("LlamaBidirectionalModel")
+class LlamaEmbedNemotronModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA_EMBED
+
+
 @ModelBase.register("BailingMoeForCausalLM")
 class BailingMoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BAILINGMOE
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 41d3bd4faf2389e1e2171f573feb118da041a3f7..baff8547abe101485c6a56229814b4aecdf15907 100644 (file)
@@ -449,6 +449,7 @@ class MODEL_ARCH(IntEnum):
     RND1             = auto()
     PANGU_EMBED      = auto()
     MISTRAL3         = auto()
+    LLAMA_EMBED      = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -844,6 +845,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.RND1:             "rnd1",
     MODEL_ARCH.PANGU_EMBED:      "pangu-embedded",
     MODEL_ARCH.MISTRAL3:         "mistral3",
+    MODEL_ARCH.LLAMA_EMBED:      "llama-embed",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -3196,6 +3198,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.LLAMA_EMBED: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ]
     # TODO
 }
 
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 80f44ae1bfe78f2c3461aed0eb23076f4dfaacba..73420d3c9e16d882ba2325e4c0515f4e33f90925 100644 (file)
@@ -115,6 +115,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_RND1,             "rnd1"             },
     { LLM_ARCH_PANGU_EMBED,      "pangu-embedded"   },
     { LLM_ARCH_MISTRAL3,         "mistral3"         },
+    { LLM_ARCH_LLAMA_EMBED,      "llama-embed"      },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -500,6 +501,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_DECI:
         case LLM_ARCH_MISTRAL3:
+        case LLM_ARCH_LLAMA_EMBED:
             return {
                 LLM_TENSOR_TOKEN_EMBD,
                 LLM_TENSOR_OUTPUT_NORM,
diff --git a/src/llama-arch.h b/src/llama-arch.h
index a53bc39d1833914cf85985c99387b70efa3db9f6..433ee4bc18fb9d307fc23ffd6e191dfe945d26e2 100644 (file)
@@ -119,6 +119,7 @@ enum llm_arch {
     LLM_ARCH_RND1,
     LLM_ARCH_PANGU_EMBED,
     LLM_ARCH_MISTRAL3,
+    LLM_ARCH_LLAMA_EMBED,
     LLM_ARCH_UNKNOWN,
 };
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 0d5bcc64fe553d44f27d94745062bb226add4e9b..9fada915d7d1729e296690ff604abb3c075c166d 100644 (file)
@@ -606,7 +606,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
 
-        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
+        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) {
             if (hparams.n_rot != hparams.n_embd_head_k) {
                 throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
             }
@@ -630,6 +630,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLAMA_EMBED:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
@@ -2652,6 +2653,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             case LLM_ARCH_GRANITE:
             case LLM_ARCH_GRANITE_MOE:
             case LLM_ARCH_MISTRAL3:
+            case LLM_ARCH_LLAMA_EMBED:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
@@ -7269,16 +7271,20 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
     switch (arch) {
         case LLM_ARCH_LLAMA:
             {
-                llm = std::make_unique<llm_build_llama>(*this, params);
+                llm = std::make_unique<llm_build_llama<false>>(*this, params);
             } break;
         case LLM_ARCH_LLAMA4:
             {
                 if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) {
-                    llm = std::make_unique<llm_build_llama>(*this, params);
+                    llm = std::make_unique<llm_build_llama<false>>(*this, params);
                 } else {
                     llm = std::make_unique<llm_build_llama_iswa>(*this, params);
                 }
             } break;
+        case LLM_ARCH_LLAMA_EMBED:
+            {
+                llm = std::make_unique<llm_build_llama<true>>(*this, params);
+            } break;
         case LLM_ARCH_DECI:
             {
                 llm = std::make_unique<llm_build_deci>(*this, params);
@@ -7874,6 +7880,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_ERNIE4_5:
         case LLM_ARCH_ERNIE4_5_MOE:
         case LLM_ARCH_MISTRAL3:
+        case LLM_ARCH_LLAMA_EMBED:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
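Note: llama-model.cpp routes LLM_ARCH_LLAMA_EMBED to llm_build_llama<true>, which uses the no-cache attention input instead of the KV-cache one. Given the upstream name LlamaBidirectionalModel, the practical effect is bidirectional attention: every token may attend to every other token rather than only to earlier ones. A minimal sketch of that mask difference (toy code, not llama.cpp internals):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Build an n x n additive attention mask: 0.0f = "may attend",
    // -INFINITY = masked out before the softmax. Row i is the query
    // position, column j is the key position.
    static std::vector<float> build_mask(int n, bool bidirectional) {
        std::vector<float> mask(n * n, 0.0f);
        if (!bidirectional) {
            for (int i = 0; i < n; ++i) {
                for (int j = i + 1; j < n; ++j) {
                    mask[i * n + j] = -INFINITY;   // query i cannot see the future token j
                }
            }
        }
        return mask;
    }

    int main() {
        const int n = 4;
        const auto causal = build_mask(n, /*bidirectional=*/false);  // decoder-style (llama)
        const auto bidir  = build_mask(n, /*bidirectional=*/true);   // encoder-style (llama-embed)
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                std::printf("%s", std::isinf(causal[i * n + j]) ? " -inf" : "  0.0");
            }
            std::printf("\n");
        }
        (void) bidir;
        return 0;
    }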
diff --git a/src/models/llama.cpp b/src/models/llama.cpp
index ab7fd5d0508665e7f9465ecda75c9feb5ebaee04..42b5fcdf42eb80fd5c705bffe0644de229d05cf5 100644 (file)
@@ -1,6 +1,7 @@
 #include "models.h"
 
-llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+template <bool embed>
+llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -14,7 +15,14 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
-    auto * inp_attn = build_attn_inp_kv();
+    using inp_attn_type = std::conditional_t<embed, llm_graph_input_attn_no_cache, llm_graph_input_attn_kv>;
+
+    inp_attn_type * inp_attn = nullptr;
+    if constexpr (embed) {
+        inp_attn = build_attn_inp_no_cache();
+    } else {
+        inp_attn = build_attn_inp_kv();
+    }
 
     const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -145,11 +153,16 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para
     cb(cur, "result_norm", -1);
     res->t_embd = cur;
 
-    // lm_head
-    cur = build_lora_mm(model.output, cur);
+    if constexpr (!embed) {
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
 
-    cb(cur, "result_output", -1);
-    res->t_logits = cur;
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+    }
 
     ggml_build_forward_expand(gf, cur);
 }
+
+template struct llm_build_llama<false>;
+template struct llm_build_llama<true>;
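Note: the two trailing "template struct llm_build_llama<...>;" lines are explicit instantiation definitions. Because the templated constructor body lives in this .cpp file rather than the header, llama-model.cpp would otherwise fail to link when it constructs llm_build_llama<false> or llm_build_llama<true>. The same pattern in miniature, with hypothetical names (widget.h / widget.cpp are illustrative only):

    // widget.h (hypothetical): the template is declared, the ctor is not defined here
    template <bool flag>
    struct widget {
        widget();   // definition lives in widget.cpp, just like llm_build_llama's ctor
    };

    // widget.cpp (hypothetical): the out-of-line definition plus explicit instantiations
    template <bool flag>
    widget<flag>::widget() {
        // shared implementation for both variants
    }

    // Without the next two lines, another translation unit doing
    // std::make_unique<widget<true>>() would compile but fail to link,
    // because the ctor body is invisible there and cannot be implicitly
    // instantiated.
    template struct widget<false>;
    template struct widget<true>;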
diff --git a/src/models/models.h b/src/models/models.h
index 53a5810659adcc1b37bc8871512fefa377411924..fca505b30ab98acbb639c7612947e3e30dac43b9 100644 (file)
@@ -303,6 +303,7 @@ struct llm_build_llada_moe : public llm_graph_context {
     llm_build_llada_moe(const llama_model & model, const llm_graph_params & params);
 };
 
+template <bool embed>
 struct llm_build_llama : public llm_graph_context {
     llm_build_llama(const llama_model & model, const llm_graph_params & params);
 };
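For context (not part of this commit), a converted llama-embed GGUF would be exercised through the usual embeddings path of the C API. A rough sketch, assuming a model file named llama-embed.gguf and mean pooling; the calls follow llama.cpp's existing embedding example and may shift between versions:

    #include "llama.h"
    #include <cstdio>
    #include <cstring>

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("llama-embed.gguf", mparams);  // assumed path
        const llama_vocab * vocab = llama_model_get_vocab(model);

        llama_context_params cparams = llama_context_default_params();
        cparams.embeddings   = true;                      // request embeddings, not logits
        cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN;   // assumption: mean-pool the sequence
        llama_context * ctx = llama_init_from_model(model, cparams);

        const char * text = "hello world";
        llama_token tokens[64];
        const int n_tokens = llama_tokenize(vocab, text, (int) std::strlen(text),
                                            tokens, 64, /*add_special=*/true, /*parse_special=*/false);

        // embedding-only graphs are typically driven via llama_encode
        // (the embedding example picks encode/decode based on the model)
        llama_batch batch = llama_batch_get_one(tokens, n_tokens);
        llama_encode(ctx, batch);

        const float * emb    = llama_get_embeddings_seq(ctx, 0);   // pooled sequence embedding
        const int     n_embd = llama_model_n_embd(model);
        std::printf("embedding[0..1] = %f %f (%d dims)\n", emb[0], emb[1], n_embd);

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }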