]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add Jina Embeddings architecture (#6826)
authorJoan Fontanals <redacted>
Sat, 11 May 2024 07:46:09 +0000 (09:46 +0200)
committerGitHub <redacted>
Sat, 11 May 2024 07:46:09 +0000 (10:46 +0300)
* feat: first things to do

* feat: create tensors for Jina architecture

* fix: use other tensors

* feat: embedding gets results

* fix: fix usage of ALIBI

* fix: clean prints

* fix: do some cleanup unused vars

* fix: revert changes to Makefile and CMakeLists

* fix: revert some changes

* fix: fix small detail

* fix: fix convert formatting

* fix: fix linting and editor

* feat: set proper vocab settings

* fix: JinaBertForMaskedLM registration

* feat: support q_normalization and k_normalization in Jina arch

* feat: handle gpt2 tokenizer with Jina architecture

* feat: example comments in embedding

* feat: rename Jina Bert to Jina Bert V2

* fix: add some changes as per review

* feat: proper KQ_pos for Jina embeddings

* feat: add capacity to load models ES and DE for Spanish

* llama : fix pre-tokenizers

* ggml : full ALiBi support

* ggml : update ggml_soft_max_ext() CUDA, SYCL

* ggml : ggml_flash_attn_ext() support ALiBi (CPU)

* ggml : ggml_flash_attn_ext() support ALiBi (Metal)

* ggml : fix warning

* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)

ggml-ci

* minor : clean-up

* embedding : add warning about missing SEP

---------

Co-authored-by: Georgi Gerganov <redacted>
convert-hf-to-gguf-update.py
convert-hf-to-gguf.py
examples/embedding/embedding.cpp
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp

index b5eb41eacdab74830c3102141446cd17fcccfe08..e757d5ccbc0b4f86dd09adaf93d91ac46320a9bc 100755 (executable)
@@ -74,6 +74,9 @@ models = [
     {"name": "qwen2",          "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
     {"name": "olmo",           "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
     {"name": "dbrx",           "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
+    {"name": "jina-en",        "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
+    {"name": "jina-es",        "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
+    {"name": "jina-de",        "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
 ]
 
 # make directory "models/tokenizers" if it doesn't exist
index 3315ca74b044be3a491b3f3e40f5194eed1fe5a7..fbaed64da1cac4d12569f35701357df927e2078d 100755 (executable)
@@ -404,8 +404,17 @@ class Model:
             # ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
             res = "olmo"
         if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
-            # ref: https://huggingface.co/databricks/dbrx-instruct
+            # ref: https://huggingface.co/databricks/dbrx-base
             res = "dbrx"
+        if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
+            # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
+            res = "jina-en"
+        if chkhsh == "171aeeedd6fb548d418a7461d053f11b6f1f1fc9b387bd66640d28a4b9f5c643":
+            # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-es
+            res = "jina-es"
+        if chkhsh == "27949a2493fc4a9f53f5b9b029c82689cfbe5d3a1929bb25e043089e28466de6":
+            # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-de
+            res = "jina-de"
 
         if res is None:
             logger.warning("\n")
@@ -2289,6 +2298,43 @@ class OlmoModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("JinaBertModel", "JinaBertForMaskedLM")
+class JinaBertV2Model(BertModel):
+    model_arch = gguf.MODEL_ARCH.JINA_BERT_V2
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.intermediate_size = self.hparams["intermediate_size"]
+
+    def get_tensors(self):
+        for name, data in super().get_tensors():
+            if 'gated_layers' in name:
+                d1 = data[:self.intermediate_size, :]
+                name1 = name.replace('gated_layers', 'gated_layers_w')
+                d2 = data[self.intermediate_size:, :]
+                name2 = name.replace('gated_layers', 'gated_layers_v')
+                yield name1, d1
+                yield name2, d2
+                continue
+
+            yield name, data
+
+    def set_vocab(self, *args, **kwargs):
+        tokenizer_class = 'BertTokenizer'
+        with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
+            tokenizer_class = json.load(f)['tokenizer_class']
+
+        if tokenizer_class == 'BertTokenizer':
+            super().set_vocab()
+        elif tokenizer_class == 'RobertaTokenizer':
+            self._set_vocab_gpt2()
+            self.gguf_writer.add_token_type_count(2)
+        else:
+            raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(True)
+
+
 ###### CONVERSION LOGIC ######
 
 
index 6a93147d70e88959f66766bcadd88a08578ead6d..c85a2da53d129fd6402f433e2a7623c4bd31a4dd 100644 (file)
@@ -49,6 +49,12 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
 
         float * out = output + batch.seq_id[i][0] * n_embd;
+        //TODO: I would also add a parameter here to enable normalization or not.
+        /*fprintf(stdout, "unnormalized_embedding:");
+        for (int hh = 0; hh < n_embd; hh++) {
+            fprintf(stdout, "%9.6f ", embd[hh]);
+        }
+        fprintf(stdout, "\n");*/
         llama_embd_normalize(embd, out, n_embd);
     }
 }
@@ -123,10 +129,12 @@ int main(int argc, char ** argv) {
         inputs.push_back(inp);
     }
 
-    // add SEP if not present
+    // check if the last token is SEP
+    // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
     for (auto & inp : inputs) {
         if (inp.empty() || inp.back() != llama_token_sep(model)) {
-            inp.push_back(llama_token_sep(model));
+            fprintf(stderr, "%s: warning: last token in the prompt is not SEP\n", __func__);
+            fprintf(stderr, "%s:          'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
         }
     }
 
index 5951c0bb0fb5ead7949b6d8008052fa83e015ba1..a4fbfc5e09d0632840c42fad1af76609b300069a 100644 (file)
@@ -118,6 +118,7 @@ class MODEL_ARCH(IntEnum):
     REFACT     = auto()
     BERT       = auto()
     NOMIC_BERT = auto()
+    JINA_BERT_V2 = auto()
     BLOOM      = auto()
     STABLELM   = auto()
     QWEN       = auto()
@@ -195,6 +196,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.REFACT:         "refact",
     MODEL_ARCH.BERT:           "bert",
     MODEL_ARCH.NOMIC_BERT:     "nomic-bert",
+    MODEL_ARCH.JINA_BERT_V2:   "jina-bert-v2",
     MODEL_ARCH.BLOOM:          "bloom",
     MODEL_ARCH.STABLELM:       "stablelm",
     MODEL_ARCH.QWEN:           "qwen",
@@ -380,6 +382,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
+    MODEL_ARCH.JINA_BERT_V2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.MPT: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 990fe63c2acd19cddd5c4cff6fa60b62df7c0153..8e1cac9152f55ea2491e9dc31d3c7b78a53c042d 100644 (file)
@@ -243,6 +243,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.w3",                     # internlm2
             "encoder.layers.{bid}.mlp.fc11",                          # nomic-bert
             "model.layers.{bid}.mlp.c_fc",                            # starcoder2
+            "encoder.layer.{bid}.mlp.gated_layers_v",                 # jina-bert-v2
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -269,6 +270,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.gate_proj",    # plamo
             "model.layers.{bid}.feed_forward.w1",         # internlm2
             "encoder.layers.{bid}.mlp.fc12",              # nomic-bert
+            "encoder.layer.{bid}.mlp.gated_layers_w",     # jina-bert-v2
             "transformer.h.{bid}.mlp.linear_1",           # refact
         ),
 
@@ -303,6 +305,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.w2",                     # internlm2
             "encoder.layers.{bid}.mlp.fc2",                           # nomic-bert
             "model.layers.{bid}.mlp.c_proj",                          # starcoder2
+            "encoder.layer.{bid}.mlp.wo",                             # jina-bert-v2
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -321,6 +324,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.q_layernorm",                       # persimmon
             "model.layers.{bid}.self_attn.q_norm",                            # cohere
             "transformer.blocks.{bid}.attn.q_ln",                             # sea-lion
+            "encoder.layer.{bid}.attention.self.layer_norm_q"                 # jina-bert-v2
         ),
 
         MODEL_TENSOR.ATTN_K_NORM: (
@@ -328,6 +332,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.k_layernorm",                       # persimmon
             "model.layers.{bid}.self_attn.k_norm",                            # cohere
             "transformer.blocks.{bid}.attn.k_ln",                             # sea-lion
+            "encoder.layer.{bid}.attention.self.layer_norm_k"                 # jina-bert-v2
         ),
 
         MODEL_TENSOR.ROPE_FREQS: (
@@ -338,6 +343,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.output.LayerNorm",         # bert
             "encoder.layers.{bid}.norm2",                   # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_3",   # Grok
+            "encoder.layer.{bid}.mlp.layernorm",            # jina-bert-v2
         ),
 
         MODEL_TENSOR.SSM_IN: (
index dede68cb5a3f9bf4c55897b9fb21ccd9fba15c11..cdff28cdaa7737b2bda610400d0d8b1ae20f5b27 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -205,6 +205,7 @@ enum llm_arch {
     LLM_ARCH_REFACT,
     LLM_ARCH_BERT,
     LLM_ARCH_NOMIC_BERT,
+    LLM_ARCH_JINA_BERT_V2,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,
@@ -228,39 +229,40 @@ enum llm_arch {
 };
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
-    { LLM_ARCH_LLAMA,           "llama"      },
-    { LLM_ARCH_FALCON,          "falcon"     },
-    { LLM_ARCH_GROK,            "grok"       },
-    { LLM_ARCH_GPT2,            "gpt2"       },
-    { LLM_ARCH_GPTJ,            "gptj"       },
-    { LLM_ARCH_GPTNEOX,         "gptneox"    },
-    { LLM_ARCH_MPT,             "mpt"        },
-    { LLM_ARCH_BAICHUAN,        "baichuan"   },
-    { LLM_ARCH_STARCODER,       "starcoder"  },
-    { LLM_ARCH_PERSIMMON,       "persimmon"  },
-    { LLM_ARCH_REFACT,          "refact"     },
-    { LLM_ARCH_BERT,            "bert"       },
-    { LLM_ARCH_NOMIC_BERT,      "nomic-bert" },
-    { LLM_ARCH_BLOOM,           "bloom"      },
-    { LLM_ARCH_STABLELM,        "stablelm"   },
-    { LLM_ARCH_QWEN,            "qwen"       },
-    { LLM_ARCH_QWEN2,           "qwen2"      },
-    { LLM_ARCH_QWEN2MOE,        "qwen2moe"   },
-    { LLM_ARCH_PHI2,            "phi2"       },
-    { LLM_ARCH_PHI3,            "phi3"       },
-    { LLM_ARCH_PLAMO,           "plamo"      },
-    { LLM_ARCH_CODESHELL,       "codeshell"  },
-    { LLM_ARCH_ORION,           "orion"      },
-    { LLM_ARCH_INTERNLM2,       "internlm2"  },
-    { LLM_ARCH_MINICPM,         "minicpm"    },
-    { LLM_ARCH_GEMMA,           "gemma"      },
-    { LLM_ARCH_STARCODER2,      "starcoder2" },
-    { LLM_ARCH_MAMBA,           "mamba"      },
-    { LLM_ARCH_XVERSE,          "xverse"     },
-    { LLM_ARCH_COMMAND_R,       "command-r"  },
-    { LLM_ARCH_DBRX,            "dbrx"       },
-    { LLM_ARCH_OLMO,            "olmo"       },
-    { LLM_ARCH_UNKNOWN,         "(unknown)"  },
+    { LLM_ARCH_LLAMA,           "llama"        },
+    { LLM_ARCH_FALCON,          "falcon"       },
+    { LLM_ARCH_GROK,            "grok"         },
+    { LLM_ARCH_GPT2,            "gpt2"         },
+    { LLM_ARCH_GPTJ,            "gptj"         },
+    { LLM_ARCH_GPTNEOX,         "gptneox"      },
+    { LLM_ARCH_MPT,             "mpt"          },
+    { LLM_ARCH_BAICHUAN,        "baichuan"     },
+    { LLM_ARCH_STARCODER,       "starcoder"    },
+    { LLM_ARCH_PERSIMMON,       "persimmon"    },
+    { LLM_ARCH_REFACT,          "refact"       },
+    { LLM_ARCH_BERT,            "bert"         },
+    { LLM_ARCH_NOMIC_BERT,      "nomic-bert"   },
+    { LLM_ARCH_JINA_BERT_V2,    "jina-bert-v2" },
+    { LLM_ARCH_BLOOM,           "bloom"        },
+    { LLM_ARCH_STABLELM,        "stablelm"     },
+    { LLM_ARCH_QWEN,            "qwen"         },
+    { LLM_ARCH_QWEN2,           "qwen2"        },
+    { LLM_ARCH_QWEN2MOE,        "qwen2moe"     },
+    { LLM_ARCH_PHI2,            "phi2"         },
+    { LLM_ARCH_PHI3,            "phi3"         },
+    { LLM_ARCH_PLAMO,           "plamo"        },
+    { LLM_ARCH_CODESHELL,       "codeshell"    },
+    { LLM_ARCH_ORION,           "orion"        },
+    { LLM_ARCH_INTERNLM2,       "internlm2"    },
+    { LLM_ARCH_MINICPM,         "minicpm"      },
+    { LLM_ARCH_GEMMA,           "gemma"        },
+    { LLM_ARCH_STARCODER2,      "starcoder2"   },
+    { LLM_ARCH_MAMBA,           "mamba"        },
+    { LLM_ARCH_XVERSE,          "xverse"       },
+    { LLM_ARCH_COMMAND_R,       "command-r"    },
+    { LLM_ARCH_DBRX,            "dbrx"         },
+    { LLM_ARCH_OLMO,            "olmo"         },
+    { LLM_ARCH_UNKNOWN,         "(unknown)"    },
 };
 
 enum llm_kv {
@@ -691,6 +693,25 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_JINA_BERT_V2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_BLOOM,
         {
@@ -3778,6 +3799,12 @@ static void llm_load_hparams(
 
     // get hparams kv
     ml.get_key(LLM_KV_VOCAB_SIZE,           hparams.n_vocab,       false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
+
+    // everything past this point is not vocab-related
+    if (hparams.vocab_only) {
+        return;
+    }
+
     ml.get_key(LLM_KV_CONTEXT_LENGTH,       hparams.n_ctx_train);
     ml.get_key(LLM_KV_EMBEDDING_LENGTH,     hparams.n_embd);
     ml.get_key(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff);
@@ -3961,6 +3988,19 @@ static void llm_load_hparams(
                         model.type = e_model::MODEL_335M; break; // bge-large
                 }
             } break;
+        case LLM_ARCH_JINA_BERT_V2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
+                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
+                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type);
+                hparams.f_max_alibi_bias = 8.0f;
+
+                switch (hparams.n_layer) {
+                    case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small
+                    case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base
+                }
+            } break;
         case LLM_ARCH_NOMIC_BERT:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
@@ -4382,7 +4422,9 @@ static void llm_load_vocab(
                     tokenizer_pre == "starcoder") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER;
             } else if (
-                    tokenizer_pre == "gpt-2") {
+                    tokenizer_pre == "gpt-2"   ||
+                    tokenizer_pre == "jina-es" ||
+                    tokenizer_pre == "jina-de") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2;
             } else if (
                     tokenizer_pre == "refact") {
@@ -5241,6 +5283,50 @@ static bool llm_load_tensors(
                         layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd});
                     }
                 } break;
+            case LLM_ARCH_JINA_BERT_V2:
+                {
+                    model.tok_embd     = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}); // word_embeddings
+                    model.type_embd    = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); //token_type_embeddings
+                    model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm
+                    model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}); //LayerNorm bias
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i]; // JinaBertLayer
+
+                        layer.wq   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.bq   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd});
+
+                        layer.attn_q_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, false);
+                        layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, false);
+
+                        layer.wk   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.bk   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa});
+
+                        layer.attn_k_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, false);
+                        layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, false);
+
+                        layer.wv   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.bv   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa});
+
+                        layer.wo              = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd}); //output_dens
+                        layer.bo              = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "bias", i), {n_embd}); //output_dens
+
+                        layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm
+                        layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd});
+
+                        layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff});
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE,    "weight", i), {n_embd, n_ff});
+
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN,        "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN,      "bias", i), {n_embd});
+
+                        layer.layer_out_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM,        "weight", i), {n_embd});
+                        layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM,        "bias", i), {n_embd});
+                    }
+                } break;
             case LLM_ARCH_BLOOM:
                 {
                     model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
@@ -6317,7 +6403,7 @@ static struct ggml_tensor * llm_build_ffn(
           llm_ffn_gate_type   type_gate,
          const llm_build_cb & cb,
                         int   il) {
-    struct ggml_tensor * tmp = ggml_mul_mat(ctx, up, cur);
+    struct ggml_tensor * tmp = up ? ggml_mul_mat(ctx, up, cur) : cur;
     cb(tmp, "ffn_up", il);
 
     if (up_b) {
@@ -8118,8 +8204,11 @@ struct llm_build_context {
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
+        struct ggml_tensor * inp_pos = nullptr;
 
-        struct ggml_tensor * inp_pos  = build_inp_pos();
+        if (model.arch != LLM_ARCH_JINA_BERT_V2) {
+            inp_pos = build_inp_pos();
+        }
         struct ggml_tensor * inp_mean = build_inp_mean();
         struct ggml_tensor * inp_cls  = build_inp_cls();
 
@@ -8150,13 +8239,26 @@ struct llm_build_context {
             struct ggml_tensor * Vcur;
 
             // self-attention
-            if (model.arch == LLM_ARCH_BERT) {
+            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
                 Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
+                            model.layers[il].attn_q_norm,
+                            model.layers[il].attn_q_norm_b,
+                            LLM_NORM, cb, il);
+                }
+
                 Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
+                            model.layers[il].attn_k_norm,
+                            model.layers[il].attn_k_norm_b,
+                            LLM_NORM, cb, il);
+                }
                 Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
@@ -8247,6 +8349,13 @@ struct llm_build_context {
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+            } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
             } else {
                 cur = llm_build_ffn(ctx0, cur,
                         model.layers[il].ffn_up,   NULL,
@@ -10769,6 +10878,7 @@ static struct ggml_cgraph * llama_build_graph(
                 result = llm.build_refact();
             } break;
         case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
             {
                 result = llm.build_bert();
@@ -12695,7 +12805,10 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     }
                 }
 
-                GGML_ASSERT(vocab.special_add_eos != 1);
+                if (add_special && vocab.special_add_eos == 1) {
+                    GGML_ASSERT(vocab.special_add_eos != -1);
+                    output.push_back(vocab.special_eos_id);
+                }
             } break;
         case LLAMA_VOCAB_TYPE_WPM:
             {
@@ -15746,6 +15859,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_REFACT:
         case LLM_ARCH_BLOOM:
         case LLM_ARCH_MAMBA:
+        case LLM_ARCH_JINA_BERT_V2:
             return LLAMA_ROPE_TYPE_NONE;
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values