]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add support for BERT embedding models (#5423)
authorDouglas Hanley <redacted>
Sun, 11 Feb 2024 16:21:38 +0000 (10:21 -0600)
committerGitHub <redacted>
Sun, 11 Feb 2024 16:21:38 +0000 (11:21 -0500)
* BERT model graph construction (build_bert)
* WordPiece tokenizer (llm_tokenize_wpm)
* Add flag for non-causal attention models
* Allow for models that only output embeddings
* Support conversion of BERT models to GGUF
* Based on prior work by @xyzhang626 and @skeskinen

---------

Co-authored-by: Jared Van Bortel <redacted>
Co-authored-by: Jared Van Bortel <redacted>
Co-authored-by: Georgi Gerganov <redacted>
.flake8
convert-hf-to-gguf.py
examples/embedding/embedding.cpp
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
llama.cpp
llama.h

diff --git a/.flake8 b/.flake8
index 113ca5fd3cb17dec32f0a3b42874a4b02120d5b8..18fba2c1574a6952001ea1003bd0bb96ce7c53ea 100644 (file)
--- a/.flake8
+++ b/.flake8
@@ -1,2 +1,3 @@
 [flake8]
 max-line-length = 125
+ignore = W503
index 0d4ea03b44051c731d7cfc12df79e0c4a09234db..cae1551a236b06798768224aab3de9da3c88a3b5 100755 (executable)
@@ -209,6 +209,8 @@ class Model:
             return InternLM2Model
         if model_architecture == "MiniCPMForCausalLM":
             return MiniCPMModel
+        if model_architecture == "BertModel":
+            return BertModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -264,6 +266,8 @@ class Model:
             return gguf.MODEL_ARCH.INTERNLM2
         if arch == "MiniCPMForCausalLM":
             return gguf.MODEL_ARCH.MINICPM
+        if arch == "BertModel":
+            return gguf.MODEL_ARCH.BERT
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -1629,6 +1633,96 @@ in chat mode so that the conversation can end normally.")
                 self.post_write_tensors(tensor_map, name, data_torch)
 
 
+class BertModel(Model):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.block_count = self.hparams["num_hidden_layers"]
+
+    def set_gguf_parameters(self):
+        # TODO(cebtenzzre): merge with parent class
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
+        self.gguf_writer.add_causal_attention(False)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def set_vocab(self):
+        path = self.dir_model
+        added_tokens_path = self.dir_model if self.dir_model.exists() else None
+
+        # use huggingface vocab to get all tokens
+        vocab = HfVocab(path, added_tokens_path)
+        tokens, scores, toktypes = zip(*vocab.all_tokens())
+        assert len(tokens) == vocab.vocab_size
+
+        # we need this to validate the size of the token_type embeddings
+        # though currently we are passing all zeros to the token_type embeddings
+        n_token_types = len(set(toktypes))
+        self.gguf_writer.add_token_type_count(n_token_types)
+
+        # convert to phantom space vocab
+        def phantom(tok, typ):
+            if tok.startswith(b"[") and tok.endswith(b"]"):
+                return tok
+            if tok.startswith(b"##"):
+                return tok[2:]
+            return b"\xe2\x96\x81" + tok
+        tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]
+
+        # set up bos and eos tokens (cls and sep)
+        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
+        self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
+
+        # add vocab to gguf
+        self.gguf_writer.add_tokenizer_model("bert")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # handle special tokens
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def write_tensors(self):
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+        tensors = dict(self.get_tensors())
+        for name, data_torch in tensors.items():
+            # we are only using BERT for embeddings so we don't need the pooling layer
+            if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
+                continue  # we don't need these
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            data = data_torch.squeeze().numpy()
+            n_dims = len(data.shape)
+            new_dtype: type[np.floating[Any]]
+
+            if (
+                self.ftype == 1 and name.endswith(".weight") and n_dims == 2
+                and name != "embeddings.token_type_embeddings.weight"  # not used with get_rows, must be F32
+            ):
+                # if f16 desired, convert any float32 2-dim weight tensors to float16
+                new_dtype = np.float16
+            else:
+                # if f32 desired, convert any float16 to float32
+                new_dtype = np.float32
+
+            print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")
+
+            if data.dtype != new_dtype:
+                data = data.astype(new_dtype)
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######
 
 
index 3295cd2400ac3d25280dc61f450b693c1c4bfb20..27376c8f09fdcc6624c34ce1e40a36c02437e12f 100644 (file)
@@ -87,7 +87,17 @@ int main(int argc, char ** argv) {
     }
 
     const int n_embd = llama_n_embd(model);
-    const auto * embeddings = llama_get_embeddings(ctx);
+    auto * embeddings = llama_get_embeddings(ctx);
+
+    // l2-normalize embeddings
+    float norm = 0;
+    for (int i = 0; i < n_embd; i++) {
+        norm += embeddings[i] * embeddings[i];
+    }
+    norm = sqrt(norm);
+    for (int i = 0; i < n_embd; i++) {
+        embeddings[i] /= norm;
+    }
 
     for (int i = 0; i < n_embd; i++) {
         printf("%f ", embeddings[i]);
index 1cfd41c0be1fef408997cb9c51dfc63c5be23a6f..a9c13dd3826b872adf7cf8d7fac11e01cd1fa965 100644 (file)
@@ -50,6 +50,7 @@ class Keys:
         VALUE_LENGTH      = "{arch}.attention.value_length"
         LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
         LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
+        CAUSAL            = "{arch}.attention.causal"
 
     class Rope:
         DIMENSION_COUNT      = "{arch}.rope.dimension_count"
@@ -60,22 +61,23 @@ class Keys:
         SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"
 
     class Tokenizer:
-        MODEL         = "tokenizer.ggml.model"
-        LIST          = "tokenizer.ggml.tokens"
-        TOKEN_TYPE    = "tokenizer.ggml.token_type"
-        SCORES        = "tokenizer.ggml.scores"
-        MERGES        = "tokenizer.ggml.merges"
-        BOS_ID        = "tokenizer.ggml.bos_token_id"
-        EOS_ID        = "tokenizer.ggml.eos_token_id"
-        UNK_ID        = "tokenizer.ggml.unknown_token_id"
-        SEP_ID        = "tokenizer.ggml.seperator_token_id"
-        PAD_ID        = "tokenizer.ggml.padding_token_id"
-        ADD_BOS       = "tokenizer.ggml.add_bos_token"
-        ADD_EOS       = "tokenizer.ggml.add_eos_token"
-        ADD_PREFIX    = "tokenizer.ggml.add_space_prefix"
-        HF_JSON       = "tokenizer.huggingface.json"
-        RWKV          = "tokenizer.rwkv.world"
-        CHAT_TEMPLATE = "tokenizer.chat_template"
+        MODEL            = "tokenizer.ggml.model"
+        LIST             = "tokenizer.ggml.tokens"
+        TOKEN_TYPE       = "tokenizer.ggml.token_type"
+        TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count"  # for BERT-style token types
+        SCORES           = "tokenizer.ggml.scores"
+        MERGES           = "tokenizer.ggml.merges"
+        BOS_ID           = "tokenizer.ggml.bos_token_id"
+        EOS_ID           = "tokenizer.ggml.eos_token_id"
+        UNK_ID           = "tokenizer.ggml.unknown_token_id"
+        SEP_ID           = "tokenizer.ggml.seperator_token_id"
+        PAD_ID           = "tokenizer.ggml.padding_token_id"
+        ADD_BOS          = "tokenizer.ggml.add_bos_token"
+        ADD_EOS          = "tokenizer.ggml.add_eos_token"
+        ADD_PREFIX       = "tokenizer.ggml.add_space_prefix"
+        HF_JSON          = "tokenizer.huggingface.json"
+        RWKV             = "tokenizer.rwkv.world"
+        CHAT_TEMPLATE    = "tokenizer.chat_template"
 
 
 #
@@ -122,6 +124,7 @@ class MODEL_TENSOR(IntEnum):
     ATTN_OUT        = auto()
     ATTN_NORM       = auto()
     ATTN_NORM_2     = auto()
+    ATTN_OUT_NORM   = auto()
     ATTN_ROT_EMBD   = auto()
     FFN_GATE_INP    = auto()
     FFN_NORM        = auto()
@@ -134,6 +137,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_UP_EXP      = auto()
     ATTN_Q_NORM     = auto()
     ATTN_K_NORM     = auto()
+    LAYER_OUT_NORM  = auto()
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -178,6 +182,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.ATTN_ROT_EMBD:   "blk.{bid}.attn_rot_embd",
     MODEL_TENSOR.ATTN_Q_NORM:     "blk.{bid}.attn_q_norm",
     MODEL_TENSOR.ATTN_K_NORM:     "blk.{bid}.attn_k_norm",
+    MODEL_TENSOR.ATTN_OUT_NORM:   "blk.{bid}.attn_output_norm",
     MODEL_TENSOR.FFN_GATE_INP:    "blk.{bid}.ffn_gate_inp",
     MODEL_TENSOR.FFN_NORM:        "blk.{bid}.ffn_norm",
     MODEL_TENSOR.FFN_GATE:        "blk.{bid}.ffn_gate",
@@ -187,6 +192,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.FFN_GATE_EXP:    "blk.{bid}.ffn_gate.{xid}",
     MODEL_TENSOR.FFN_DOWN_EXP:    "blk.{bid}.ffn_down.{xid}",
     MODEL_TENSOR.FFN_UP_EXP:      "blk.{bid}.ffn_up.{xid}",
+    MODEL_TENSOR.LAYER_OUT_NORM:  "blk.{bid}.layer_output_norm",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -262,17 +268,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
     ],
     MODEL_ARCH.BERT: [
         MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
         MODEL_TENSOR.TOKEN_TYPES,
         MODEL_TENSOR.POS_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
     ],
     MODEL_ARCH.MPT: [
         MODEL_TENSOR.TOKEN_EMBD,
index 16808196e769d11ee9041bdf5f6fca6f7f6a5ba9..7af58a46c2cb73332d8301e0ecab1c642582a154 100644 (file)
@@ -357,6 +357,9 @@ class GGUFWriter:
     def add_layer_norm_rms_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
 
+    def add_causal_attention(self, value: bool) -> None:
+        self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
+
     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
 
@@ -387,6 +390,9 @@ class GGUFWriter:
     def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
         self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)
 
+    def add_token_type_count(self, value: int) -> None:
+        self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)
+
     def add_token_scores(self, scores: Sequence[float]) -> None:
         self.add_array(Keys.Tokenizer.SCORES, scores)
 
index 4f16d85044693910bc459a78dc22e49dc882213e..c7ba1420e0453ea9630b5fa4b0508b2658c7205f 100644 (file)
@@ -30,6 +30,7 @@ class TensorNameMap:
         # Normalization of token embeddings
         MODEL_TENSOR.TOKEN_EMBD_NORM: (
             "word_embeddings_layernorm",  # bloom
+            "embeddings.LayerNorm",       # bert
         ),
 
         # Position embeddings
@@ -54,7 +55,6 @@ class TensorNameMap:
             "transformer.ln_f",                        # gpt2 gpt-j falcon
             "model.norm",                              # llama-hf baichuan internlm2
             "norm",                                    # llama-pth
-            "embeddings.LayerNorm",                    # bert
             "transformer.norm_f",                      # mpt
             "ln_f",                                    # refact bloom qwen gpt2
             "language_model.encoder.final_layernorm",  # persimmon
@@ -79,7 +79,6 @@ class TensorNameMap:
             "transformer.h.{bid}.ln_mlp",                           # falcon40b
             "model.layers.{bid}.input_layernorm",                   # llama-hf
             "layers.{bid}.attention_norm",                          # llama-pth
-            "encoder.layer.{bid}.attention.output.LayerNorm",       # bert
             "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
             "model.layers.{bid}.ln1",                               # yi
             "h.{bid}.ln_1",                                         # gpt2
@@ -155,6 +154,11 @@ class TensorNameMap:
             "model.layers.{bid}.attention.wo",                           # internlm2
         ),
 
+        # Attention output norm
+        MODEL_TENSOR.ATTN_OUT_NORM: (
+            "encoder.layer.{bid}.attention.output.LayerNorm",  # bert
+        ),
+
         # Rotary embeddings
         MODEL_TENSOR.ATTN_ROT_EMBD: (
             "model.layers.{bid}.self_attn.rotary_emb.inv_freq",        # llama-hf
@@ -171,7 +175,6 @@ class TensorNameMap:
             "transformer.blocks.{bid}.norm_2",                               # mpt
             "model.layers.{bid}.post_attention_layernorm",                   # llama-hf
             "layers.{bid}.ffn_norm",                                         # llama-pth
-            "encoder.layer.{bid}.output.LayerNorm",                          # bert
             "language_model.encoder.layers.{bid}.post_attention_layernorm",  # persimmon
             "model.layers.{bid}.ln2",                                        # yi
             "h.{bid}.ln_2",                                                  # gpt2
@@ -266,6 +269,10 @@ class TensorNameMap:
         MODEL_TENSOR.ROPE_FREQS: (
             "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq",  # persimmon
         ),
+
+        MODEL_TENSOR.LAYER_OUT_NORM: (
+            "encoder.layer.{bid}.output.LayerNorm",  # bert
+        )
     }
 
     mapping: dict[str, tuple[MODEL_TENSOR, str]]
index 3f39a67fb109613269167e432d74a8c44eec6dc9..d1ee26ce2a9c89bce8bbc9531a78a70eaa70b950 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -196,6 +196,7 @@ enum llm_arch {
     LLM_ARCH_STARCODER,
     LLM_ARCH_PERSIMMON,
     LLM_ARCH_REFACT,
+    LLM_ARCH_BERT,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,
@@ -220,6 +221,7 @@ static std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_STARCODER,       "starcoder" },
     { LLM_ARCH_PERSIMMON,       "persimmon" },
     { LLM_ARCH_REFACT,          "refact"    },
+    { LLM_ARCH_BERT,            "bert"      },
     { LLM_ARCH_BLOOM,           "bloom"     },
     { LLM_ARCH_STABLELM,        "stablelm"  },
     { LLM_ARCH_QWEN,            "qwen"      },
@@ -261,6 +263,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_VALUE_LENGTH,
     LLM_KV_ATTENTION_LAYERNORM_EPS,
     LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
+    LLM_KV_ATTENTION_CAUSAL,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -273,6 +276,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_LIST,
     LLM_KV_TOKENIZER_TOKEN_TYPE,
+    LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
     LLM_KV_TOKENIZER_SCORES,
     LLM_KV_TOKENIZER_MERGES,
     LLM_KV_TOKENIZER_BOS_ID,
@@ -316,6 +320,7 @@ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_VALUE_LENGTH,        "%s.attention.value_length"           },
     { LLM_KV_ATTENTION_LAYERNORM_EPS,       "%s.attention.layer_norm_epsilon"     },
     { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,   "%s.attention.layer_norm_rms_epsilon" },
+    { LLM_KV_ATTENTION_CAUSAL,              "%s.attention.causal"                 },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 },
     { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       },
@@ -328,6 +333,7 @@ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_MODEL,               "tokenizer.ggml.model"              },
     { LLM_KV_TOKENIZER_LIST,                "tokenizer.ggml.tokens"             },
     { LLM_KV_TOKENIZER_TOKEN_TYPE,          "tokenizer.ggml.token_type"         },
+    { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,    "tokenizer.ggml.token_type_count"   },
     { LLM_KV_TOKENIZER_SCORES,              "tokenizer.ggml.scores"             },
     { LLM_KV_TOKENIZER_MERGES,              "tokenizer.ggml.merges"             },
     { LLM_KV_TOKENIZER_BOS_ID,              "tokenizer.ggml.bos_token_id"       },
@@ -355,6 +361,7 @@ struct LLM_KV {
 enum llm_tensor {
     LLM_TENSOR_TOKEN_EMBD,
     LLM_TENSOR_TOKEN_EMBD_NORM,
+    LLM_TENSOR_TOKEN_TYPES,
     LLM_TENSOR_POS_EMBD,
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
@@ -536,6 +543,23 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_BERT,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_POS_EMBD,        "position_embd" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_BLOOM,
         {
@@ -1440,6 +1464,11 @@ static llama_state g_state;
 // available llama models
 enum e_model {
     MODEL_UNKNOWN,
+    MODEL_17M,
+    MODEL_22M,
+    MODEL_33M,
+    MODEL_109M,
+    MODEL_335M,
     MODEL_0_5B,
     MODEL_1B,
     MODEL_2B,
@@ -1481,6 +1510,7 @@ struct llama_hparams {
     uint32_t n_ff;
     uint32_t n_expert = 0;
     uint32_t n_expert_used = 0;
+    uint32_t n_vocab_type = 0; // for BERT-style token types
 
     float f_norm_eps;
     float f_norm_rms_eps;
@@ -1493,6 +1523,8 @@ struct llama_hparams {
     float f_clamp_kqv;
     float f_max_alibi_bias;
 
+    bool causal_attn = true;
+
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only    != other.vocab_only)    return true;
@@ -1720,6 +1752,7 @@ struct llama_model {
     llama_vocab   vocab;
 
     struct ggml_tensor * tok_embd;
+    struct ggml_tensor * type_embd;
     struct ggml_tensor * pos_embd;
     struct ggml_tensor * tok_norm;
     struct ggml_tensor * tok_norm_b;
@@ -1850,6 +1883,7 @@ struct llama_context {
     struct ggml_tensor * inp_pos;       // I32 [n_batch]
     struct ggml_tensor * inp_KQ_mask;   // F32 [n_ctx, n_batch]
     struct ggml_tensor * inp_K_shift;   // I32 [n_ctx]
+    struct ggml_tensor * inp_sum;       // F32 [1, n_batch]
 
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
@@ -2829,6 +2863,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
     switch (type) {
         case LLAMA_VOCAB_TYPE_SPM:         return "SPM";
         case LLAMA_VOCAB_TYPE_BPE:         return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:         return "WPM";
         default:                           return "unknown";
     }
 }
@@ -3000,6 +3035,26 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_BERT:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
+                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
+
+                switch (hparams.n_layer) {
+                    case 3:
+                        model.type = e_model::MODEL_17M; break; // bge-micro
+                    case 6:
+                        model.type = e_model::MODEL_22M; break; // MiniLM-L6
+                    case 12:
+                        switch (hparams.n_embd) {
+                            case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small
+                            case 768: model.type = e_model::MODEL_109M; break; // bge-base
+                        } break;
+                    case 24:
+                        model.type = e_model::MODEL_335M; break; // bge-large
+                }
+            } break;
         case LLM_ARCH_BLOOM:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3204,6 +3259,16 @@ static void llm_load_vocab(
             vocab.special_unk_id = -1;
             vocab.special_sep_id = -1;
             vocab.special_pad_id = -1;
+        } else if (tokenizer_name == "bert") {
+            vocab.type = LLAMA_VOCAB_TYPE_WPM;
+
+            // default special tokens
+            vocab.special_bos_id = 101;
+            vocab.special_eos_id = 102;
+            vocab.special_unk_id = 100;
+            vocab.special_sep_id = -1;
+            vocab.special_pad_id = -1;
+            vocab.add_space_prefix = false;
         } else {
             LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
             LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
@@ -3232,6 +3297,8 @@ static void llm_load_vocab(
     // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
     if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
         vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
+    } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
+        vocab.linefeed_id = vocab.special_pad_id;
     } else {
         const std::vector<int> ids = llama_tokenize_internal(vocab, "\u010A", false);
         GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
@@ -3569,6 +3636,7 @@ static bool llm_load_tensors(
         const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
         const int64_t n_embd_gqa   = n_embd_v_gqa;
         const int64_t n_vocab      = hparams.n_vocab;
+        const int64_t n_vocab_type = hparams.n_vocab_type;
         const int64_t n_ff         = hparams.n_ff;
 
         GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
@@ -3783,11 +3851,50 @@ static bool llm_load_tensors(
                         layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {64});
                     }
                 } break;
-            case LLM_ARCH_BLOOM:
+            case LLM_ARCH_BERT:
                 {
                     model.tok_embd   = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
-                    model.tok_norm   = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
-                    model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});
+                    model.type_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES,     "weight"), {n_embd, n_vocab_type});
+                    model.pos_embd   = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,        "weight"), {n_embd, hparams.n_ctx_train});
+                    model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
+                    model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
+
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i),   {n_embd_gqa});
+
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa});
+
+                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+
+                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+
+                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                    }
+                } break;
+            case LLM_ARCH_BLOOM:
+                {
+                    model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
+                    model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
+                    model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});
 
                     // output
                     {
@@ -4739,6 +4846,7 @@ struct llm_build_context {
     const int32_t n_orig_ctx;
 
     const bool do_rope_shift;
+    const bool causal_attn;
 
     const llm_build_cb & cb;
 
@@ -4782,6 +4890,7 @@ struct llm_build_context {
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
         do_rope_shift    (worst_case || kv_self.has_shift),
+        causal_attn      (hparams.causal_attn),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
             // all initializations should be done in init()
@@ -5625,6 +5734,100 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_bert() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        // get input vectors with right size
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+        struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
+
+        // construct input embeddings (token, type, position)
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+        // token types are hardcoded to zero ("Sentence A")
+        struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
+        inpL = ggml_add(ctx0, inpL, type_row0);
+        inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
+        cb(inpL, "inp_embd", -1);
+
+        // embed layer norm
+        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+        cb(inpL, "inp_norm", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens]
+
+        // iterate layers
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * cur = inpL;
+
+            // self-attention
+            {
+                struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                // seems like we just need to do this for Q?
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            // re-add the layer input
+            cur = ggml_add(ctx0, cur, inpL);
+
+            // attention layer norm
+            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, cb, il);
+
+            struct ggml_tensor * ffn_inp = cur;
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_ffn(ctx0, cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                    NULL,                      NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                    NULL,
+                    LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+            cb(cur, "ffn_out", il);
+
+            // attentions bypass the intermediate layer
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            // output layer norm
+            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        // final output
+        cur = inpL;
+
+        // pooling
+        cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
+        cb(cur, "result_embed", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_bloom() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -7060,7 +7263,8 @@ static struct ggml_cgraph * llama_build_graph(
 
                     for (int i = 0; i < n_kv; ++i) {
                         float f;
-                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
+                            (llm.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
                             f = -INFINITY;
                         } else {
                             f = 0;
@@ -7081,6 +7285,15 @@ static struct ggml_cgraph * llama_build_graph(
                 data[i] = lctx.kv_self.cells[i].delta;
             }
         }
+
+        {
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
+            float * data = (float *) lctx.inp_sum->data;
+
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                data[i] = 1.0f/float(batch.n_tokens);
+            }
+        }
     }
 
     llm.init();
@@ -7110,6 +7323,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_refact();
             } break;
+        case LLM_ARCH_BERT:
+            {
+                result = llm.build_bert();
+            } break;
         case LLM_ARCH_BLOOM:
             {
                 result = llm.build_bloom();
@@ -7269,13 +7486,18 @@ static int llama_decode_internal(
 
     // the output is always the last tensor in the graph
     struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-    GGML_ASSERT(strcmp(res->name, "result_output") == 0);
-
-    // the embeddings could be the second to last tensor, or the third to last tensor
     struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
-    if (strcmp(embeddings->name, "result_norm") != 0) {
-        embeddings = gf->nodes[gf->n_nodes - 3];
-        GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+    if (strcmp(res->name, "result_output") == 0) {
+        // the embeddings could be the second to last tensor, or the third to last tensor
+        if (strcmp(embeddings->name, "result_norm") != 0) {
+            embeddings = gf->nodes[gf->n_nodes - 3];
+            GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+        }
+    } else if (strcmp(res->name, "result_embed") == 0) {
+        embeddings = res;
+        res = nullptr;
+    } else {
+        GGML_ASSERT(false);
     }
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -7344,7 +7566,7 @@ static int llama_decode_internal(
     // extract logits
     // TODO: do not compute and extract logits if only embeddings are needed
     //       need to update the graphs to skip "result_output"
-    {
+    if (res) {
         auto & logits_out = lctx.logits;
 
 #ifndef NDEBUG
@@ -7388,9 +7610,11 @@ static int llama_decode_internal(
     if (!lctx.embedding.empty()) {
         auto & embedding_out = lctx.embedding;
 
+        const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0;
+
         embedding_out.resize(n_embd);
         ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
-        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float));
+        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float));
         ggml_backend_synchronize(embeddings_backend);
     }
 
@@ -7454,6 +7678,9 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
         GGML_ASSERT(false);
         return unicode_to_bytes_bpe(token_data.text);
     }
+    case LLAMA_VOCAB_TYPE_WPM: {
+        GGML_ASSERT(false);
+    }
     default:
         GGML_ASSERT(false);
     }
@@ -7466,6 +7693,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
             const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
             return vocab.token_to_id.at(buf);
         }
+        case LLAMA_VOCAB_TYPE_WPM:
         case LLAMA_VOCAB_TYPE_BPE: {
             return vocab.token_to_id.at(bytes_to_unicode_bpe(ch));
         }
@@ -7936,12 +8164,212 @@ private:
     llm_bigram_bpe::queue work_queue;
 };
 
-typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{
+struct llm_tokenizer_wpm {
+    llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        auto * token_map = &vocab.token_to_id;
+
+        // normalize and split by whitespace
+        std::vector<std::string> words = preprocess(text);
+
+        // bos token prepended already
+
+        // find the longest tokens that form the words
+        for (const std::string &word : words) {
+            // skip empty words
+            if (word.size() == 0) {
+                continue;
+            }
+
+            // prepend phantom space
+            std::string word1 = "\xe2\x96\x81" + word;
+            int n = word1.size();
+
+            // we're at the start of a new word
+            int i = 0;
+            bool match_any = false;
+
+            // move through character position in word
+            while (i < n) {
+                // loop through possible match length
+                bool match = false;
+                for (int j = n; j > i; j--) {
+                    auto it = token_map->find(word1.substr(i, j - i));
+                    if (it != token_map->end()) {
+                        output.push_back(it->second);
+                        match = true;
+                        match_any = true;
+                        i = j;
+                        break;
+                    }
+                }
+
+                // must be an unknown character
+                if (!match) {
+                    i++;
+                }
+            }
+
+            // we didn't find any matches for this word
+            if (!match_any) {
+                output.push_back(vocab.special_unk_id);
+            }
+        }
+
+        // append eos token
+        output.push_back(vocab.special_eos_id);
+    }
+
+    std::vector<std::string> preprocess(const std::string & text) {
+        std::string ori_str = normalize(text);
+        uint64_t ori_size = ori_str.size();
+
+        // single punct / single symbol / single digit
+        // baseline: add whitespace on the left and right of punct and chinese characters
+        std::vector<std::string> words;
+        std::string new_str = "";
+        uint64_t i = 0;
+        while (i < ori_size) {
+            int utf_char_len = utf8_len(ori_str[i]);
+            if ((utf_char_len == 1) && ispunct(ori_str[i])) {
+                new_str += " ";
+                new_str += ori_str[i];
+                new_str += " ";
+                i += 1;
+            }
+            else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) {
+                new_str += " ";
+                new_str += ori_str.substr(i, 3);
+                new_str += " ";
+                i += 3;
+            }
+            else {
+                new_str += ori_str[i];
+                i += 1;
+            }
+        }
+
+        // split by whitespace
+        uint64_t l = 0;
+        uint64_t r = 0;
+        while (r < new_str.size()) {
+            // if is whitespace
+            if (isspace(new_str[r])) {
+                if (r > l) words.push_back(new_str.substr(l, (r - l)));
+                l = r + 1;
+                r = l;
+            }
+            else {
+                r += 1;
+            }
+        }
+        if (r > l) {
+            words.push_back(new_str.substr(l, (r - l)));
+        }
+        return words;
+    }
+
+    std::string normalize(const std::string & text) {
+        // TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98
+        std::string text2 = strip_accents(text);
+        for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) {
+            char c = text2[i];
+            if (c >= 'A' && c <= 'Z') {
+                text2[i] = c - 'A' + 'a';
+            }
+        }
+        return text2;
+    }
+
+    bool is_chinese_char(const std::string & str) {
+        int len = str.length();
+        unsigned int codepoint = 0;
+        int num_bytes = 0;
+        int i = 0;
+        unsigned char ch = static_cast<unsigned char>(str[i]);
+        if (ch <= 0x7f) {
+            codepoint = ch;
+            num_bytes = 1;
+        } else if ((ch >> 5) == 0x06) {
+            codepoint = ch & 0x1f;
+            num_bytes = 2;
+        } else if ((ch >> 4) == 0x0e) {
+            codepoint = ch & 0x0f;
+            num_bytes = 3;
+        } else if ((ch >> 3) == 0x1e) {
+            codepoint = ch & 0x07;
+            num_bytes = 4;
+        }
+        for (int j = 1; j < num_bytes; ++j) {
+            if (i + j >= len) {
+                return false; // incomplete UTF-8 character
+            }
+            unsigned char next_ch = static_cast<unsigned char>(str[i + j]);
+            if ((next_ch >> 6) != 0x02) {
+                return false; // invalid trailing byte
+            }
+            codepoint = (codepoint << 6) | (next_ch & 0x3f);
+        }
+        if ((codepoint >= 0x4E00  && codepoint <= 0x9FFF)  ||
+            (codepoint >= 0x3400  && codepoint <= 0x4DBF)  ||
+            (codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
+            (codepoint >= 0x2A700 && codepoint <= 0x2B73F) ||
+            (codepoint >= 0x2B740 && codepoint <= 0x2B81F) ||
+            (codepoint >= 0x2B920 && codepoint <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
+            (codepoint >= 0xF900  && codepoint <= 0xFAFF)  ||
+            (codepoint >= 0x2F800 && codepoint <= 0x2FA1F) ||
+            (codepoint >= 0x3000  && codepoint <= 0x303F)  ||
+            (codepoint >= 0xFF00  && codepoint <= 0xFFEF)) {
+            return true; // NOLINT
+        }
+        return false;
+    }
+
+    std::string strip_accents(const std::string & input_string) {
+        std::string resultString;
+        std::map<std::string, char> accent_map = {
+            {"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'},
+            {"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'},
+            {"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'},
+            {"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'},
+            {"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'},
+            {"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'},
+            {"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'},
+            {"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'},
+            {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'},
+        };
+
+        for (size_t i = 0; i <  input_string.length();) {
+            int len = utf8_len(input_string[i]);
+            std::string curChar = input_string.substr(i, len);
+            auto iter = accent_map.find(curChar);
+            if (iter != accent_map.end()) {
+                resultString += iter->second;
+            } else {
+                resultString += curChar;
+            }
+            i += len;
+        }
+
+        return resultString;
+    }
+
+    static size_t utf8_len(char src) {
+        const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
+        uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+        return lookup[highbits];
+    }
+
+    const llama_vocab & vocab;
+};
+
+typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
     FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
     FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
 } FRAGMENT_BUFFER_VARIANT_TYPE;
 
-struct fragment_buffer_variant{
+struct fragment_buffer_variant {
     fragment_buffer_variant(llama_vocab::id _token)
     :
         type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
@@ -7971,8 +8399,7 @@ struct fragment_buffer_variant{
 
 // #define PRETOKENIZERDEBUG
 
-static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
-{
+static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
     // for each special token
     for (const auto & st: vocab.special_tokens_cache) {
         const auto & special_token = st.first;
@@ -8090,10 +8517,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
-                for (const auto & fragment: fragment_buffer)
-                {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
-                    {
+                for (const auto & fragment: fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         // without adding this leading whitespace, we do not get the same results as the original tokenizer
 
                         // TODO: It's likely possible to get rid of this string copy entirely
@@ -8113,19 +8538,15 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                         llm_tokenizer_spm tokenizer(vocab);
                         llama_escape_whitespace(raw_text);
                         tokenizer.tokenize(raw_text, output);
-                    }
-                    else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                    {
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
                     }
                 }
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
-                for (const auto & fragment: fragment_buffer)
-                {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
-                    {
+                for (const auto & fragment: fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
 #ifdef PRETOKENIZERDEBUG
@@ -8133,9 +8554,23 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 #endif
                         llm_tokenizer_bpe tokenizer(vocab);
                         tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
                     }
-                    else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                    {
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            {
+                for (const auto & fragment: fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+                        llm_tokenizer_wpm tokenizer(vocab);
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
                     }
                 }
@@ -10799,7 +11234,7 @@ struct llama_context * llama_new_context_with_model(
         // graph inputs
         {
             ggml_init_params init_params = {
-                /* .mem_size   */ ggml_tensor_overhead()*5,
+                /* .mem_size   */ ggml_tensor_overhead()*7,
                 /* .mem_buffer */ nullptr,
                 /* .no_alloc   */ true,
             };
@@ -10810,12 +11245,14 @@ struct llama_context * llama_new_context_with_model(
             ctx->inp_pos     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
             ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
             ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
+            ctx->inp_sum     = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch);
 
             ggml_set_name(ctx->inp_tokens,  "inp_tokens");
             ggml_set_name(ctx->inp_embd,    "inp_embd");
             ggml_set_name(ctx->inp_pos,     "inp_pos");
             ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask");
             ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
+            ggml_set_name(ctx->inp_sum,     "inp_sum");
 
             ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
 
@@ -11746,6 +12183,7 @@ static std::string llama_decode_text(const std::string & text) {
 int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length) {
     if (0 <= token && token < llama_n_vocab(model)) {
         switch (llama_vocab_get_type(model->vocab)) {
+        case LLAMA_VOCAB_TYPE_WPM:
         case LLAMA_VOCAB_TYPE_SPM: {
             // NOTE: we accept all unsupported token types,
             // suppressing them like CONTROL tokens.
diff --git a/llama.h b/llama.h
index cec4158bc8e803840de2a242dd212c9f2e3b9d80..367e8f1a105a5f8cc9f7d8d046f29a467e6262a5 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -61,6 +61,7 @@ extern "C" {
     enum llama_vocab_type {
         LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
         LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
+        LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
     };
 
     enum llama_token_type {