]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : jina-embeddings-v3 support (#13693)
authorSigbjørn Skjæret <redacted>
Thu, 28 Aug 2025 13:49:50 +0000 (15:49 +0200)
committerGitHub <redacted>
Thu, 28 Aug 2025 13:49:50 +0000 (15:49 +0200)
* initial jina-embeddings-v3 support

* initial jina-embeddings-v3 support

* initial jina-embeddings-v3 support

* fix vocab parsing with only tokenizer.json

* set mask token lstrip attribute

* additional unk_token_id fallback just in case [no ci]

* revert vocab_size() change [no ci]

* merge tensor loading into general bert

* rope

* add lora embedding and loading (non-functional)

* export separate lora ggufs instead

* add adapter metadata api

* use std::string

* convert_hf_to_lora compatibility

* fix assert

* apply suggestions from review

* apply suggestion from review

14 files changed:
common/arg.cpp
common/common.cpp
common/common.h
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
include/llama.h
src/llama-adapter.cpp
src/llama-adapter.h
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/llama-model.h
src/llama-vocab.cpp
tools/server/server.cpp

index d82f55890dee5398f9fa063e53f01c53ac8a2540..93f0108b2b93bc646fb26822a65cd2459d73c9b7 100644 (file)
@@ -2555,7 +2555,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--lora"}, "FNAME",
         "path to LoRA adapter (can be repeated to use multiple adapters)",
         [](common_params & params, const std::string & value) {
-            params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
+            params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
         }
         // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
@@ -2563,7 +2563,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--lora-scaled"}, "FNAME", "SCALE",
         "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
         [](common_params & params, const std::string & fname, const std::string & scale) {
-            params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
+            params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
         }
         // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
index fdce1dcdec19b613efe0c73e8fcc26eb1d4c968e..054b43be770da67bc7442ebbc5cb816ce4e52cb7 100644 (file)
@@ -988,7 +988,12 @@ struct common_init_result common_init_from_params(common_params & params) {
             return iparams;
         }
 
+        char buf[1024];
         la.ptr = lora.get();
+        llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
+        la.task_name = buf;
+        llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
+        la.prompt_prefix = buf;
         iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
     }
 
index 390dda5e531bedb275504d72fa3751b292e9c303..87ea0606954a359293c07668f0af1ad4d92d9e8b 100644 (file)
@@ -34,6 +34,9 @@ struct common_adapter_lora_info {
     std::string path;
     float scale;
 
+    std::string task_name;
+    std::string prompt_prefix;
+
     struct llama_adapter_lora * ptr;
 };
 
index 31a11cbec0baafcc69a9a6a9d2660cbce4d4bd01..6c8a034060536c2ebb8ddada21e76f0d36f14d76 100755 (executable)
@@ -72,6 +72,7 @@ class ModelBase:
     endianess: gguf.GGUFEndian
     use_temp_file: bool
     lazy: bool
+    dry_run: bool
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
@@ -111,6 +112,7 @@ class ModelBase:
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
         self.lazy = not eager or (remote_hf_model_id is not None)
+        self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
         if remote_hf_model_id is not None:
             self.is_safetensors = True
@@ -4871,11 +4873,35 @@ class NeoBert(BertModel):
 @ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
 class XLMRobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
+    _lora_files = {}
+    _lora_names = []
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            hparams = ModelBase.load_hparams(dir_model, False)
+
+        if lora_names := hparams.get("lora_adaptations"):
+            self._lora_names = lora_names
+            self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+
+        super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
         self._xlmroberta_tokenizer_init()
 
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        if self._lora_names:
+            for name in self._lora_names:
+                fname = self.add_prefix_to_filename(self.fname_out, f"lora-{name}-")
+                self._lora_files[name] = gguf.GGUFWriter(fname, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, dry_run=self.dry_run)
+
+        return super().generate_extra_tensors()
+
+    def set_type(self):
+        for lora_writer in self._lora_files.values():
+            lora_writer.add_type(gguf.GGUFType.ADAPTER)
+            lora_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
+        super().set_type()
+
     def set_vocab(self):
         self._xlmroberta_set_vocab()
 
@@ -4885,13 +4911,62 @@ class XLMRobertaModel(BertModel):
         if name.startswith("roberta."):
             name = name[8:]
 
+        # jina-embeddings-v3
+        if ".parametrizations." in name:
+            name = name.replace(".parametrizations.", ".")
+            if name.endswith(".original"):
+                name = name[:-9]
+
         # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
         if name == "embeddings.position_embeddings.weight":
             if self._position_offset is not None:
                 data_torch = data_torch[self._position_offset:,:]
 
+        if name.endswith(".0.lora_A") or name.endswith(".0.lora_B"):
+            if name.startswith("pooler.dense"):
+                return []
+
+            num_loras = data_torch.size(0)
+            assert num_loras == len(self._lora_names)
+
+            # Split out each LoRA in their own GGUF
+            for i, lora_writer in enumerate(self._lora_files.values()):
+                new_name = self.map_tensor_name(name[:-9]) + name[-7:].lower()
+                data = data_torch[i, :, :]
+                # Transpose/flip token_embd/types into correct shape
+                if new_name == "token_embd.weight.lora_b":
+                    data = data.T
+                elif new_name.startswith("token_types.weight."):
+                    new_name = new_name[:-1] + ("a" if new_name[-1:] == "b" else "b")
+                lora_writer.add_tensor(new_name, data.float().numpy(), raw_dtype=gguf.GGMLQuantizationType.F32)
+
+            return []
+
         return super().modify_tensors(data_torch, name, bid)
 
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # jina-embeddings-v3
+        if rotary_emb_base := self.hparams.get("rotary_emb_base"):
+            self.gguf_writer.add_rope_freq_base(rotary_emb_base)
+        lora_alpha = self.hparams.get("lora_alpha")
+        if lora_prompt_prefixes := self.hparams.get("task_instructions"):
+            assert self._lora_files and all(lora_name in lora_prompt_prefixes for lora_name in self._lora_files.keys())
+        for lora_name, lora_writer in self._lora_files.items():
+            lora_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha if lora_alpha is not None else 1.0)
+            lora_writer.add_string(gguf.Keys.Adapter.LORA_TASK_NAME, lora_name)
+            if lora_prompt_prefixes:
+                lora_writer.add_string(gguf.Keys.Adapter.LORA_PROMPT_PREFIX, lora_prompt_prefixes[lora_name])
+
+    def write(self):
+        super().write()
+        for lora_writer in self._lora_files.values():
+            lora_writer.write_header_to_file()
+            lora_writer.write_kv_data_to_file()
+            lora_writer.write_tensors_to_file(progress=True)
+            lora_writer.close()
+
 
 @ModelBase.register("GemmaForCausalLM")
 class GemmaModel(TextModel):
index b9d1235d1706d1985737081ef53b03cc168198e4..a581f9601f0fd1e85730fdaa87a68c6f5b8af30b 100644 (file)
@@ -231,8 +231,10 @@ class Keys:
         MIDDLE_ID            = "tokenizer.ggml.middle_token_id"
 
     class Adapter:
-        TYPE       = "adapter.type"
-        LORA_ALPHA = "adapter.lora.alpha"
+        TYPE               = "adapter.type"
+        LORA_ALPHA         = "adapter.lora.alpha"
+        LORA_TASK_NAME     = "adapter.lora.task_name"
+        LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix"
 
     class IMatrix:
         CHUNK_COUNT = "imatrix.chunk_count"
@@ -315,6 +317,7 @@ class MODEL_ARCH(IntEnum):
     NOMIC_BERT_MOE   = auto()
     NEO_BERT         = auto()
     JINA_BERT_V2     = auto()
+    JINA_BERT_V3     = auto()
     BLOOM            = auto()
     STABLELM         = auto()
     QWEN             = auto()
@@ -647,6 +650,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.NOMIC_BERT_MOE:   "nomic-bert-moe",
     MODEL_ARCH.NEO_BERT:         "neo-bert",
     MODEL_ARCH.JINA_BERT_V2:     "jina-bert-v2",
+    MODEL_ARCH.JINA_BERT_V3:     "jina-bert-v3",
     MODEL_ARCH.BLOOM:            "bloom",
     MODEL_ARCH.STABLELM:         "stablelm",
     MODEL_ARCH.QWEN:             "qwen",
@@ -1234,6 +1238,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.LAYER_OUT_NORM,
         MODEL_TENSOR.CLS,
     ],
+    MODEL_ARCH.JINA_BERT_V3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.MPT: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index c5622cc16b4c2181bc1dd2aad07bfa409add0681..702535385057fcb2529cbfe05363cdee1c85919a 100644 (file)
@@ -553,6 +553,24 @@ extern "C" {
             struct llama_model * model,
             const char * path_lora);
 
+    // Functions to access the adapter's GGUF metadata scalar values
+    // - The functions return the length of the string on success, or -1 on failure
+    // - The output string is always null-terminated and cleared on failure
+    // - When retrieving a string, an extra byte must be allocated to account for the null terminator
+    // - GGUF array values are not supported by these functions
+
+    // Get metadata value as a string by key name
+    LLAMA_API int32_t llama_adapter_meta_val_str(const struct llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size);
+
+    // Get the number of metadata key/value pairs
+    LLAMA_API int32_t llama_adapter_meta_count(const struct llama_adapter_lora * adapter);
+
+    // Get metadata key name by index
+    LLAMA_API int32_t llama_adapter_meta_key_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
+
+    // Get metadata value as a string by index
+    LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
+
     // Manually free a LoRA adapter
     // Note: loaded adapters will be free when the associated model is deleted
     LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
index 8d94034aed95debd4b1ea9269996578744ba809b..772ce1b4480886d280a5806b4da1c94bb8b1f73d 100644 (file)
@@ -163,13 +163,38 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
 
     // check metadata
     {
+        const gguf_context * gguf_ctx = ctx_gguf.get();
+
+        LLAMA_LOG_INFO("%s: Dumping metadata keys/values.\n", __func__);
+
+        // get metadata as string
+        for (int i = 0; i < gguf_get_n_kv(gguf_ctx); i++) {
+            gguf_type type = gguf_get_kv_type(gguf_ctx, i);
+            const std::string type_name =
+                type == GGUF_TYPE_ARRAY
+                ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(gguf_ctx, i)), gguf_get_arr_n(gguf_ctx, i))
+                : gguf_type_name(type);
+            const char * name = gguf_get_key(gguf_ctx, i);
+            const std::string value = gguf_kv_to_str(gguf_ctx, i);
+
+            if (type != GGUF_TYPE_ARRAY) {
+                adapter.gguf_kv.emplace(name, value);
+            }
+
+            const size_t MAX_VALUE_LEN = 40;
+            std::string print_value = value.size() > MAX_VALUE_LEN ? format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()) : value;
+            replace_all(print_value, "\n", "\\n");
+
+            LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), print_value.c_str());
+        }
+
         auto get_kv_str = [&](const std::string & key) -> std::string {
-            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
-            return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id));
+            int id = gguf_find_key(gguf_ctx, key.c_str());
+            return id < 0 ? "" : std::string(gguf_get_val_str(gguf_ctx, id));
         };
         auto get_kv_f32 = [&](const std::string & key) -> float {
-            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
-            return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id);
+            int id = gguf_find_key(gguf_ctx, key.c_str());
+            return id < 0 ? 0.0f : gguf_get_val_f32(gguf_ctx, id);
         };
         LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
 
@@ -383,6 +408,45 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
     return nullptr;
 }
 
+int32_t llama_adapter_meta_val_str(const llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size) {
+    const auto & it = adapter->gguf_kv.find(key);
+    if (it == adapter->gguf_kv.end()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_adapter_meta_count(const llama_adapter_lora * adapter) {
+    return (int)adapter->gguf_kv.size();
+}
+
+int32_t llama_adapter_meta_key_by_index(const llama_adapter_lora * adapter, int i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = adapter->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->first.c_str());
+}
+
+int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = adapter->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
 void llama_adapter_lora_free(llama_adapter_lora * adapter) {
     delete adapter;
 }
index 65824e972765bd5b8b845e0c41ec46dc81657a4f..9084e7cab08fdfe63f9c1ca012d48dec47c42054 100644 (file)
@@ -67,6 +67,9 @@ struct llama_adapter_lora {
 
     float alpha;
 
+    // gguf metadata
+    std::unordered_map<std::string, std::string> gguf_kv;
+
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;
 
index 0ca0a4c22f8149dda9c0af94dd4970c2b4d18284..a61dc177ac9b8f741f738628096d9deef813a23e 100644 (file)
@@ -22,6 +22,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_NOMIC_BERT_MOE,   "nomic-bert-moe"   },
     { LLM_ARCH_NEO_BERT,         "neo-bert"         },
     { LLM_ARCH_JINA_BERT_V2,     "jina-bert-v2"     },
+    { LLM_ARCH_JINA_BERT_V3,     "jina-bert-v3"     },
     { LLM_ARCH_BLOOM,            "bloom"            },
     { LLM_ARCH_STABLELM,         "stablelm"         },
     { LLM_ARCH_QWEN,             "qwen"             },
@@ -234,8 +235,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_FIM_REP_ID,           "tokenizer.ggml.fim_rep_token_id"         },
     { LLM_KV_TOKENIZER_FIM_SEP_ID,           "tokenizer.ggml.fim_sep_token_id"         },
 
-    { LLM_KV_ADAPTER_TYPE,       "adapter.type"       },
-    { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
+    { LLM_KV_ADAPTER_TYPE,               "adapter.type"               },
+    { LLM_KV_ADAPTER_LORA_ALPHA,         "adapter.lora.alpha"         },
+    { LLM_KV_ADAPTER_LORA_TASK_NAME,     "adapter.lora.task_name"     },
+    { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },
 
     // deprecated
     { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
@@ -575,6 +578,20 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_CLS,             "cls" },
         },
     },
+    {
+        LLM_ARCH_JINA_BERT_V3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_BLOOM,
         {
index 7008c2514c5d4debb99263f77d914ed671751f5c..94b0bef719bed00cb8b2e054d66a0717bc5c4b9b 100644 (file)
@@ -26,6 +26,7 @@ enum llm_arch {
     LLM_ARCH_NOMIC_BERT_MOE,
     LLM_ARCH_NEO_BERT,
     LLM_ARCH_JINA_BERT_V2,
+    LLM_ARCH_JINA_BERT_V3,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,
@@ -230,6 +231,8 @@ enum llm_kv {
 
     LLM_KV_ADAPTER_TYPE,
     LLM_KV_ADAPTER_LORA_ALPHA,
+    LLM_KV_ADAPTER_LORA_TASK_NAME,
+    LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,
 
     LLM_KV_POSNET_EMBEDDING_LENGTH,
     LLM_KV_POSNET_BLOCK_COUNT,
index 7d3429617bef98d2ccaa679351c6fb85f72517f9..30974a723fda1a628e69229e3735f5535eefc236 100644 (file)
@@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_410M:          return "410M";
         case LLM_TYPE_450M:          return "450M";
         case LLM_TYPE_475M:          return "475M";
+        case LLM_TYPE_558M:          return "558M";
         case LLM_TYPE_700M:          return "700M";
         case LLM_TYPE_770M:          return "770M";
         case LLM_TYPE_780M:          return "780M";
@@ -772,6 +773,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_JINA_BERT_V3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type, false);
+
+                switch (hparams.n_layer) {
+                    case 24:
+                        type = LLM_TYPE_558M; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
@@ -2631,6 +2644,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             case LLM_ARCH_BERT:
             case LLM_ARCH_NOMIC_BERT:
             case LLM_ARCH_NOMIC_BERT_MOE:
+            case LLM_ARCH_JINA_BERT_V3:
                 {
                     tok_embd     = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
                     type_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);
@@ -2666,24 +2680,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
 
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "bias", i),   {n_embd}, TENSOR_NOT_REQUIRED);
 
                         layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
                         layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd}, 0);
 
                         if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) {
-                            layer.bo         = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
                             layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff,   n_expert}, 0);
                             layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff,   n_embd, n_expert}, 0);
                             layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,   "weight", i), {n_embd, n_expert}, 0);
                         } else {
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd}, 0);
-
-                            if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) {
-                                layer.bo         = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
-                                layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, 0);
-                                layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
-                            } else {
+                            layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, TENSOR_NOT_REQUIRED);
+                            layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, TENSOR_NOT_REQUIRED);
+
+                            if (arch == LLM_ARCH_NOMIC_BERT) {
                                 layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                             }
                         }
@@ -7461,7 +7473,7 @@ struct llm_build_bert : public llm_graph_context {
                 }
 
                 // RoPE
-                if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+                if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) {
                     Qcur = ggml_rope_ext(
                             ctx0, Qcur, inp_pos, nullptr,
                             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -7520,7 +7532,7 @@ struct llm_build_bert : public llm_graph_context {
                         0.0f,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
                 cb(cur, "ffn_moe_out", il);
-            } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+            } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) {
                 cur = build_ffn(cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
@@ -18241,6 +18253,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
@@ -18395,6 +18408,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
@@ -18885,6 +18899,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GROK:
         case LLM_ARCH_DBRX:
         case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_STABLELM:
index af4460cc01eb0b231d51a5ae75a395e98c854117..fa44d800d527766c4b6e40bb5ce3e579ccad29c0 100644 (file)
@@ -40,6 +40,7 @@ enum llm_type {
     LLM_TYPE_450M,
     LLM_TYPE_475M,
     LLM_TYPE_537M,
+    LLM_TYPE_558M,
     LLM_TYPE_700M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
index de5d1681dff8544b50893f88812c89ecc8fa0e1b..ca02b63a58407951c534e4544b608ddcf7fd1169 100644 (file)
@@ -2470,7 +2470,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         // set attributes by model/tokenizer/architecture name
         if (false
                 || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
-                || _contains_any(general_arch, {"nomic-bert-moe"})
+                || _contains_any(general_arch, {"nomic-bert-moe", "jina-bert-v3"})
            ) {
             if (token_to_id.count("<mask>") == 0) {
                 LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
index 6eb5aeb582b3a9cf03ef2affb6f018af43ae2029..6aa319d2f112157a0e99de5ad999ea87874b9219 100644 (file)
@@ -4898,6 +4898,8 @@ int main(int argc, char ** argv) {
                 {"id", i},
                 {"path", lora.path},
                 {"scale", lora.scale},
+                {"task_name", lora.task_name},
+                {"prompt_prefix", lora.prompt_prefix},
             });
         }
         res_ok(res, result);