]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : add PLaMo-2 support (#14560)
authorShunta Saito <redacted>
Tue, 15 Jul 2025 16:11:42 +0000 (01:11 +0900)
committerGitHub <redacted>
Tue, 15 Jul 2025 16:11:42 +0000 (18:11 +0200)
* Add PLaMo-2 model using hybrid memory module

* Fix z shape

* Add cmath to include from llama-vocab.h

* Explicitly dequantize normalization weights before RoPE apply

* Revert unnecessary cast because the problem can be solved by excluding attn_k, attn_q when quantizing

* Use ATTN_K/Q_NORM for k,q weights to prevent quantization

* Remove SSM_BCDT that is not used from anywhere

* Do not duplicate embedding weights for output.weight

* Fix tokenizer encoding problem for multibyte strings

* Apply suggestion from @CISC

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Use LLM_FFN_SWIGLU instead of splitting ffn_gate and ffn_up

* Remove unnecessary part for Grouped Query Attention

* Fix how to load special token id to gguf

* Remove unused tensor mapping

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Remove llama_vocab_plamo2 class and replace it with llm_tokenizer_plamo2_session to follow the other tokenizer implementations

* Update src/llama-vocab.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Fix plamo2 tokenizer session to prevent multiple calls of build()

---------

Co-authored-by: Francis Couture-Harpin <redacted>
Co-authored-by: Sigbjørn Skjæret <redacted>
Co-authored-by: Georgi Gerganov <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
include/llama.h
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/llama-vocab.cpp

index c201883509ceb13e09cfc0feec0a5d5fcea87120..ba7dff355f49c8ff5f9a8b7900ea01bed2554fe5 100755 (executable)
@@ -3508,6 +3508,175 @@ class PlamoModel(TextModel):
         return [(new_name, data_torch)]
 
 
+@ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM")
+class Plamo2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.PLAMO2
+
+    def set_vocab(self):
+        # PLaMo 2 uses a custom tokenizer with a .jsonl file
+        # We need to handle this specially
+        tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
+        tokenizer_config_path = self.dir_model / "tokenizer_config.json"
+
+        if not tokenizer_jsonl_path.is_file():
+            raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}")
+
+        # Load tokenizer config
+        with open(tokenizer_config_path, 'r', encoding='utf-8') as f:
+            tokenizer_config = json.load(f)
+
+        # Load tokens from JSONL file (actually a list format)
+        tokens = []
+        scores = []
+        toktypes = []
+
+        with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f:
+            for line_num, line in enumerate(f):
+                if line.strip():
+                    token_data = json.loads(line)
+                    # Format: [token, score, type, ?, ?, ?, ?]
+                    token = token_data[0].encode("utf-8")
+                    score = float(token_data[1])
+                    token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"
+
+                    tokens.append(token)
+                    scores.append(score)
+
+                    # Map token type strings to GGUF token types
+                    if token_type_str == "UNKNOWN":
+                        toktypes.append(gguf.TokenType.UNKNOWN)
+                    elif token_type_str == "CONTROL":
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    elif token_type_str == "BYTE":
+                        toktypes.append(gguf.TokenType.BYTE)
+                    else:
+                        # Check for PLaMo-2 special tokens
+                        token_str = token_data[0]
+                        if token_str.startswith("<|plamo:") and token_str.endswith("|>"):
+                            toktypes.append(gguf.TokenType.CONTROL)
+                        else:
+                            toktypes.append(gguf.TokenType.NORMAL)
+
+        vocab_size = self.hparams["vocab_size"]
+        if vocab_size > len(tokens):
+            pad_count = vocab_size - len(tokens)
+            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
+            for i in range(1, pad_count + 1):
+                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
+                scores.append(-1000.0)
+                toktypes.append(gguf.TokenType.UNUSED)
+
+        # Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer
+        self.gguf_writer.add_tokenizer_model("plamo2")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # Add special tokens from config
+        if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None:
+            token_id = tokens.index(tokenizer_config["bos_token"].encode("utf-8"))
+            self.gguf_writer.add_bos_token_id(token_id)
+        if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None:
+            token_id = tokens.index(tokenizer_config["eos_token"].encode("utf-8"))
+            self.gguf_writer.add_eos_token_id(token_id)
+        if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None:
+            token_id = tokens.index(tokenizer_config["pad_token"].encode("utf-8"))
+            self.gguf_writer.add_pad_token_id(token_id)
+        if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None:
+            token_id = tokens.index(tokenizer_config["sep_token"].encode("utf-8"))
+            self.gguf_writer.add_sep_token_id(token_id)
+        if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None:
+            token_id = tokens.index(tokenizer_config["unk_token"].encode("utf-8"))
+            self.gguf_writer.add_unk_token_id(token_id)
+
+        # Add <|plamo:op|> as EOT to ensure appropriate end of generation
+        self.gguf_writer.add_eot_token_id(4)
+
+        self.gguf_writer.add_add_space_prefix(False)
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+
+        # Which layers are Mamba layers
+        # PLaMo 2 uses mamba_step to indicate the pattern (e.g., 2 means every other layer)
+        # This logic matches modeling_plamo.py's is_mamba function
+        mamba_step = hparams.get("mamba_step", 2)
+        mamba_enabled = hparams.get("mamba_enabled", True)
+        mamba_layers = []
+
+        if mamba_enabled:
+            for i in range(block_count):
+                if block_count <= (mamba_step // 2):
+                    # use attention in last layer
+                    is_mamba = (i != block_count - 1)
+                else:
+                    is_mamba = (i % mamba_step) != (mamba_step // 2)
+                if is_mamba:
+                    mamba_layers.append(0)
+                else:
+                    mamba_layers.append(hparams.get("num_key_value_heads", 4))
+
+        if mamba_layers:
+            self.gguf_writer.add_head_count_kv(mamba_layers)
+
+        self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
+        self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
+        self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
+        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))
+
+        # Mamba parameters
+        self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
+        self.gguf_writer.add_ssm_conv_kernel(hparams.get("mamba_d_conv", 4))
+        self.gguf_writer.add_ssm_time_step_rank(hparams.get("mamba_num_heads", 64))
+        intermediate_size = hparams.get("mamba_num_heads", 64) * hparams.get("hidden_size_per_head", 128)
+        self.gguf_writer.add_ssm_inner_size(intermediate_size)
+        self.gguf_writer.add_ssm_group_count(0)
+
+        # MLP feed forward parameters (for attention layers)
+        self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.endswith(".A_log"):
+            data_torch = -torch.exp(data_torch)
+        elif name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+        elif name.endswith(".dt_norm_weight"):
+            name = name.rpartition(".dt_norm_weight")[0] + ".dt_norm.weight"
+        elif name.endswith(".B_norm_weight"):
+            name = name.rpartition(".B_norm_weight")[0] + ".B_norm.weight"
+        elif name.endswith(".C_norm_weight"):
+            name = name.rpartition(".C_norm_weight")[0] + ".C_norm.weight"
+        elif name.endswith(".k_weight"):
+            name = name.rpartition(".k_weight")[0] + ".k.weight"
+        elif name.endswith(".q_weight"):
+            name = name.rpartition(".q_weight")[0] + ".q.weight"
+        elif name.endswith(".conv1d.weight"):
+            data_torch = torch.squeeze(data_torch)  # remove (, 1, )
+            assert data_torch.ndim == 2
+        elif name.endswith(".pre_mixer_norm.weight"):
+            data_torch += 1.0
+        elif name.endswith(".post_mixer_norm.weight"):
+            data_torch += 1.0 / 5
+        elif name.endswith(".pre_mlp_norm.weight"):
+            data_torch += 1.0
+        elif name.endswith(".post_mlp_norm.weight"):
+            data_torch += 1.0 / (5**1.5)
+        elif name.endswith(".norm.weight"):
+            data_torch += 1.0
+
+        new_name = self.map_tensor_name(name)
+
+        return [(new_name, data_torch)]
+
+
 @ModelBase.register("CodeShellForCausalLM")
 class CodeShellModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CODESHELL
index 4e2b878e189c6102d305b6793b5d86389361b8d4..486a165b68b72a49df00f39190e58f158bd0f845 100644 (file)
@@ -317,6 +317,7 @@ class MODEL_ARCH(IntEnum):
     PHI3             = auto()
     PHIMOE           = auto()
     PLAMO            = auto()
+    PLAMO2           = auto()
     CODESHELL        = auto()
     ORION            = auto()
     INTERNLM2        = auto()
@@ -631,6 +632,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.PHI3:             "phi3",
     MODEL_ARCH.PHIMOE:           "phimoe",
     MODEL_ARCH.PLAMO:            "plamo",
+    MODEL_ARCH.PLAMO2:           "plamo2",
     MODEL_ARCH.CODESHELL:        "codeshell",
     MODEL_ARCH.ORION:            "orion",
     MODEL_ARCH.INTERNLM2:        "internlm2",
@@ -1369,6 +1371,36 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.PLAMO2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_POST_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_X,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.SSM_DT_NORM,
+        MODEL_TENSOR.SSM_B_NORM,
+        MODEL_TENSOR.SSM_C_NORM,
+    ],
     MODEL_ARCH.GPT2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.POS_EMBD,
index 75855eba52c3c9cbb9873bf6df5fbb8017b1e502..2a675044f9d99276d4e0b8c203698fe5f751b8ef 100644 (file)
@@ -13,7 +13,7 @@ class TensorNameMap:
             "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings",               # falcon
             "word_embeddings",                           # bloom
-            "model.embed_tokens",                        # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 granite-hybrid
+            "model.embed_tokens",                        # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid
             "tok_embeddings",                            # llama-pth
             "embeddings.word_embeddings",                # bert nomic-bert
             "language_model.embedding.word_embeddings",  # persimmon
@@ -63,7 +63,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out",                 # gptneox
-            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe
+            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe plamo2
             "output",                    # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2
@@ -77,7 +77,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm",               # gptneox
             "transformer.ln_f",                        # gpt2 gpt-j falcon jais exaone
-            "model.norm",                              # llama-hf baichuan internlm2 olmoe olmo2 phimoe
+            "model.norm",                              # llama-hf baichuan internlm2 olmoe olmo2 phimoe plamo2
             "norm",                                    # llama-pth
             "transformer.norm_f",                      # mpt dbrx
             "ln_f",                                    # refact bloom qwen gpt2
@@ -126,6 +126,7 @@ class TensorNameMap:
             "h.{bid}.ln_1",                                         # gpt2
             "transformer.h.{bid}.ln",                               # phi2
             "model.layers.layers.{bid}.norm",                       # plamo
+            "model.layers.layers.{bid}.pre_mixer_norm",             # plamo2
             "model.layers.{bid}.attention_norm",                    # internlm2
             "model.layers.{bid}.norm",                              # mamba-qbert
             "backbone.layers.{bid}.norm",                           # mamba
@@ -163,6 +164,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.attn.Wqkv",                                      # nomic-bert
             "encoder.layers.{bid}.mixer.Wqkv",                                     # jina
             "model.layers.{bid}.self_attn.qkv_proj",                               # phi3
+            "model.layers.layers.{bid}.mixer.qkv_proj",                            # plamo2
             "encoder.layers.{bid}.self_attention.query_key_value",                 # chatglm
             "transformer.layers.{bid}.attn.qkv_proj",                              # openelm
             "transformer_encoder.{bid}.qkv",                                       # neobert
@@ -233,6 +235,7 @@ class TensorNameMap:
             "h.{bid}.attn.c_proj",                                          # gpt2
             "transformer.h.{bid}.mixer.out_proj",                           # phi2
             "model.layers.layers.{bid}.self_attn.o_proj",                   # plamo
+            "model.layers.layers.{bid}.mixer.o_proj",                       # plamo2
             "model.layers.{bid}.attention.wo",                              # internlm2
             "encoder.layers.{bid}.attn.out_proj",                           # nomic-bert
             "encoder.layers.{bid}.mixer.out_proj",                          # jina
@@ -255,8 +258,9 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.ATTN_POST_NORM: (
-            "model.layers.{bid}.post_attention_layernorm",     # gemma2 olmo2    # ge
-            "model.layers.{bid}.post_self_attn_layernorm",     # glm-4-0414
+            "model.layers.{bid}.post_attention_layernorm",       # gemma2 olmo2    # ge
+            "model.layers.{bid}.post_self_attn_layernorm",       # glm-4-0414
+            "model.layers.layers.{bid}.post_mixer_norm.weight",  # plamo2
         ),
 
         # Rotary embeddings
@@ -286,6 +290,7 @@ class TensorNameMap:
             "model.layers.{bid}.pre_moe_layernorm",                          # mini-jamba
             "model.layers.{bid}.post_attention_layernorm",                   # llama4
             "transformer_encoder.{bid}.ffn_norm",                            # neobert
+            "model.layers.layers.{bid}.pre_mlp_norm",                        # plamo2
         ),
 
         # Post feed-forward norm
@@ -298,6 +303,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_POST_NORM: (
             "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
             "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
+            "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
             "model.layers.{bid}.feed_forward.up_proj",
         ),
 
@@ -342,6 +348,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.fc1",                             # phi2
             "model.layers.{bid}.mlp.gate_up_proj",                    # phi3 glm-4-0414
             "model.layers.layers.{bid}.mlp.up_proj",                  # plamo
+            "model.layers.layers.{bid}.mlp.gate_up_proj",             # plamo2
             "model.layers.{bid}.feed_forward.w3",                     # internlm2
             "encoder.layers.{bid}.mlp.fc11",                          # nomic-bert
             "encoder.layers.{bid}.mlp.fc1",                           # nomic-bert-moe
@@ -469,6 +476,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.attn.q_ln",                             # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_q",                # jina-bert-v2
             "transformer.layers.{bid}.attn.q_norm",                           # openelm
+            "model.layers.layers.{bid}.mixer.q",                              # plamo2
         ),
 
         MODEL_TENSOR.ATTN_K_NORM: (
@@ -479,6 +487,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.attn.k_ln",                             # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_k",                # jina-bert-v2
             "transformer.layers.{bid}.attn.k_norm",                           # openelm
+            "model.layers.layers.{bid}.mixer.k",                              # plamo2
         ),
 
         MODEL_TENSOR.ROPE_FREQS: (
@@ -559,27 +568,31 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_IN: (
-            "model.layers.{bid}.in_proj",           # mamba-hf
-            "backbone.layers.{bid}.mixer.in_proj",  # mamba
-            "model.layers.{bid}.mamba.in_proj",     # jamba falcon-h1 granite-hybrid
+            "model.layers.{bid}.in_proj",               # mamba-hf
+            "backbone.layers.{bid}.mixer.in_proj",      # mamba
+            "model.layers.{bid}.mamba.in_proj",         # jamba falcon-h1 granite-hybrid
+            "model.layers.layers.{bid}.mixer.in_proj",  # plamo2
         ),
 
         MODEL_TENSOR.SSM_CONV1D: (
-            "model.layers.{bid}.conv1d",           # mamba-hf
-            "backbone.layers.{bid}.mixer.conv1d",  # mamba
-            "model.layers.{bid}.mamba.conv1d",     # jamba falcon-h1 granite-hybrid
+            "model.layers.{bid}.conv1d",               # mamba-hf
+            "backbone.layers.{bid}.mixer.conv1d",      # mamba
+            "model.layers.{bid}.mamba.conv1d",         # jamba falcon-h1 granite-hybrid
+            "model.layers.layers.{bid}.mixer.conv1d",  # plamo2
         ),
 
         MODEL_TENSOR.SSM_X: (
-            "model.layers.{bid}.x_proj",           # mamba-hf
-            "backbone.layers.{bid}.mixer.x_proj",  # mamba
-            "model.layers.{bid}.mamba.x_proj",     # jamba
+            "model.layers.{bid}.x_proj",                  # mamba-hf
+            "backbone.layers.{bid}.mixer.x_proj",         # mamba
+            "model.layers.{bid}.mamba.x_proj",            # jamba
+            "model.layers.layers.{bid}.mixer.bcdt_proj",  # plamo2
         ),
 
         MODEL_TENSOR.SSM_DT: (
-            "model.layers.{bid}.dt_proj",           # mamba-hf
-            "backbone.layers.{bid}.mixer.dt_proj",  # mamba
-            "model.layers.{bid}.mamba.dt_proj",     # jamba falcon-h1 granite-hybrid
+            "model.layers.{bid}.dt_proj",               # mamba-hf
+            "backbone.layers.{bid}.mixer.dt_proj",      # mamba
+            "model.layers.{bid}.mamba.dt_proj",         # jamba falcon-h1 granite-hybrid
+            "model.layers.layers.{bid}.mixer.dt_proj",  # plamo2
         ),
 
         MODEL_TENSOR.SSM_DT_NORM: (
@@ -587,25 +600,33 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_A: (
-            "model.layers.{bid}.A_log",           # mamba-hf
-            "backbone.layers.{bid}.mixer.A_log",  # mamba
-            "model.layers.{bid}.mamba.A_log",     # jamba falcon-h1 granite-hybrid
+            "model.layers.{bid}.A_log",               # mamba-hf
+            "backbone.layers.{bid}.mixer.A_log",      # mamba
+            "model.layers.{bid}.mamba.A_log",         # jamba falcon-h1 granite-hybrid
+            "model.layers.layers.{bid}.mixer.A_log",  # plamo2
         ),
 
         MODEL_TENSOR.SSM_B_NORM: (
-            "model.layers.{bid}.mamba.b_layernorm",  # jamba
-            "model.layers.{bid}.mamba.B_layernorm",  # mini-jamba
+            "model.layers.{bid}.mamba.b_layernorm",           # jamba
+            "model.layers.{bid}.mamba.B_layernorm",           # mini-jamba
+            "model.layers.layers.{bid}.mixer.B_norm.weight",  # plamo2
         ),
 
         MODEL_TENSOR.SSM_C_NORM: (
-            "model.layers.{bid}.mamba.c_layernorm",  # jamba
-            "model.layers.{bid}.mamba.C_layernorm",  # mini-jamba
+            "model.layers.{bid}.mamba.c_layernorm",           # jamba
+            "model.layers.{bid}.mamba.C_layernorm",           # mini-jamba
+            "model.layers.layers.{bid}.mixer.C_norm.weight",  # plamo2
         ),
 
         MODEL_TENSOR.SSM_D: (
-            "model.layers.{bid}.D",           # mamba-hf
-            "backbone.layers.{bid}.mixer.D",  # mamba
-            "model.layers.{bid}.mamba.D",     # jamba falcon-h1 granite-hybrid
+            "model.layers.{bid}.D",               # mamba-hf
+            "backbone.layers.{bid}.mixer.D",      # mamba
+            "model.layers.{bid}.mamba.D",         # jamba falcon-h1 granite-hybrid
+            "model.layers.layers.{bid}.mixer.D",  # plamo2
+        ),
+
+        MODEL_TENSOR.SSM_DT_NORM: (
+            "model.layers.layers.{bid}.mixer.dt_norm.weight",  # plamo2
         ),
 
         MODEL_TENSOR.SSM_NORM: (
@@ -614,9 +635,10 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_OUT: (
-            "model.layers.{bid}.out_proj",           # mamba-hf
-            "backbone.layers.{bid}.mixer.out_proj",  # mamba
-            "model.layers.{bid}.mamba.out_proj",     # jamba falcon-h1 granite-hybrid
+            "model.layers.{bid}.out_proj",               # mamba-hf
+            "backbone.layers.{bid}.mixer.out_proj",      # mamba
+            "model.layers.{bid}.mamba.out_proj",         # jamba falcon-h1 granite-hybrid
+            "model.layers.layers.{bid}.mixer.out_proj",  # plamo2
         ),
 
         MODEL_TENSOR.TIME_MIX_W0: (
index f73b1ab65fe6fcad6ea7afa37ab443246fa10c3e..c83b759150bfe8868f10f973baea37208bc2ab3d 100644 (file)
@@ -71,12 +71,13 @@ extern "C" {
     typedef int32_t llama_seq_id;
 
     enum llama_vocab_type {
-        LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
-        LLAMA_VOCAB_TYPE_SPM  = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
-        LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
-        LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
-        LLAMA_VOCAB_TYPE_UGM  = 4, // T5 tokenizer based on Unigram
-        LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
+        LLAMA_VOCAB_TYPE_NONE   = 0, // For models without vocab
+        LLAMA_VOCAB_TYPE_SPM    = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
+        LLAMA_VOCAB_TYPE_BPE    = 2, // GPT-2 tokenizer based on byte-level BPE
+        LLAMA_VOCAB_TYPE_WPM    = 3, // BERT tokenizer based on WordPiece
+        LLAMA_VOCAB_TYPE_UGM    = 4, // T5 tokenizer based on Unigram
+        LLAMA_VOCAB_TYPE_RWKV   = 5, // RWKV tokenizer based on greedy tokenization
+        LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
     };
 
     enum llama_rope_type {
index e63ab284bc3b59bca499add2f387bc119ec39294..5c7a0d087ce528f124b55393eb9179e771dcf576 100644 (file)
@@ -34,6 +34,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_PHI3,             "phi3"             },
     { LLM_ARCH_PHIMOE,           "phimoe"           },
     { LLM_ARCH_PLAMO,            "plamo"            },
+    { LLM_ARCH_PLAMO2,           "plamo2"           },
     { LLM_ARCH_CODESHELL,        "codeshell"        },
     { LLM_ARCH_ORION,            "orion"            },
     { LLM_ARCH_INTERNLM2,        "internlm2"        },
@@ -784,6 +785,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_PLAMO2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_SSM_IN,          "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D,      "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_X,           "blk.%d.ssm_x" },
+            { LLM_TENSOR_SSM_DT,          "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A,           "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D,           "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" },
+            { LLM_TENSOR_SSM_DT_NORM,     "blk.%d.ssm_dt_norm" },
+            { LLM_TENSOR_SSM_B_NORM,      "blk.%d.ssm_b_norm" },
+            { LLM_TENSOR_SSM_C_NORM,      "blk.%d.ssm_c_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_CODESHELL,
         {
@@ -2094,6 +2125,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
     switch (arch) {
         case LLM_ARCH_JAMBA:
         case LLM_ARCH_FALCON_H1:
+        case LLM_ARCH_PLAMO2:
         case LLM_ARCH_GRANITE_HYBRID:
         case LLM_ARCH_LFM2:
             return true;
index 1f97325952411182d875af8e52a1416f95324016..d4a2dea9ec33dcdc41be2ebcde9879e12a3177de 100644 (file)
@@ -38,6 +38,7 @@ enum llm_arch {
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,
     LLM_ARCH_PLAMO,
+    LLM_ARCH_PLAMO2,
     LLM_ARCH_CODESHELL,
     LLM_ARCH_ORION,
     LLM_ARCH_INTERNLM2,
index a322fc39352e7f2a1138f4c053df224399f90e69..ffee997b8f494772f21c9cd1ba378d03ae7e779d 100644 (file)
@@ -935,6 +935,33 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                }
             } break;
+        case LLM_ARCH_PLAMO2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                // Load Mamba SSM parameters
+                ml.get_key(LLM_KV_SSM_CONV_KERNEL,    hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_GROUP_COUNT,    hparams.ssm_n_group);
+
+                for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+                    hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0;
+                }
+
+                switch (hparams.n_layer) {
+                    case 16: type = LLM_TYPE_1B; break;
+                    case 32:
+                        if (hparams.n_embd == 2048) {
+                            type = LLM_TYPE_2B;
+                        } else if (hparams.n_embd == 4096) {
+                            type = LLM_TYPE_8B;
+                        }
+                        break;
+                    default: type = LLM_TYPE_UNKNOWN;
+               }
+            } break;
         case LLM_ARCH_GPT2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2938,6 +2965,73 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_PLAMO2:
+                {
+                    const uint32_t d_conv             = hparams.ssm_d_conv;
+                    const uint32_t d_state            = hparams.ssm_d_state;
+                    const uint32_t num_heads          = hparams.ssm_dt_rank;
+                    const uint32_t intermediate_size  = hparams.ssm_d_inner;
+                    const uint32_t head_dim           = intermediate_size / num_heads;
+                    const uint32_t qk_dim             = head_dim;
+                    const uint32_t v_dim              = head_dim;
+                    const int64_t num_attention_heads = hparams.n_head();
+                    const int64_t q_num_heads         = num_attention_heads;
+                    const int64_t dt_dim              = std::max(64, int(hparams.n_embd / 16));
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+                        bool is_mamba_layer = hparams.is_recurrent(i);
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (is_mamba_layer) {
+                            layer.ssm_in       = create_tensor(tn(LLM_TENSOR_SSM_IN,     "weight", i), {n_embd, 2 * intermediate_size}, 0);
+                            layer.ssm_conv1d   = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0);
+
+                            layer.ssm_x    = create_tensor(tn(LLM_TENSOR_SSM_X,  "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0);
+                            layer.ssm_dt   = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0);
+                            layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0);
+
+                            layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0);
+                            layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0);
+
+                            layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0);
+
+                            layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0);
+                            layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
+                            layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
+                        } else {
+                            const int64_t num_key_value_heads = hparams.n_head_kv(i);
+                            const int64_t k_num_heads         = num_key_value_heads;
+                            const int64_t v_num_heads         = num_key_value_heads;
+                            const int64_t q_proj_dim          = q_num_heads * qk_dim;
+                            const int64_t k_proj_dim          = k_num_heads * qk_dim;
+                            const int64_t v_proj_dim          = v_num_heads * v_dim;
+
+                            layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0);
+                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0);
+                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0);
+                        }
+
+                        // All layers have post-attention norm, FFN norm, and FFN tensors
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0);
+                    }
+                } break;
             case LLM_ARCH_GPT2:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -5209,6 +5303,7 @@ void llama_model::print_info() const {
         arch == LLM_ARCH_MAMBA2 ||
         arch == LLM_ARCH_JAMBA ||
         arch == LLM_ARCH_FALCON_H1 ||
+        arch == LLM_ARCH_PLAMO2 ||
         arch == LLM_ARCH_GRANITE_HYBRID) {
         LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
         LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
@@ -15476,6 +15571,320 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba {
     }
 };
 
+struct llm_build_plamo2 : public llm_graph_context_mamba {
+    llm_build_plamo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = build_inp_embd(model.tok_embd);
+        cb(inpL, "embedding_output", -1);
+
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_hybrid = build_inp_mem_hybrid();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * residual = inpL;
+
+            // ggml_graph_add_node(gf, model.layers[il].attn_norm);
+            // cb(model.layers[il].attn_norm, "attn_norm", il);
+
+            // pre_mixer_norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+
+            // check if this layer is Mamba or Attention
+            bool is_mamba_layer = hparams.is_recurrent(il);
+
+            if (is_mamba_layer) {
+                // PLaMo-2 Mamba layer
+                cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il);
+            } else {
+                // PLaMo-2 Attention layer
+                cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, gf, cur, model, il);
+            }
+
+            // post_mixer_norm
+            cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            // residual connection
+            cur = ggml_add(ctx0, cur, residual);
+            cb(cur, "attn_residual", il);
+            residual = cur;
+
+            // pre-ffn norm
+            cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_pre_norm", il);
+
+            // feed-forward network
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    NULL,                      NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            // post ffn norm
+            cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_post_norm", il);
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+            }
+
+            // residual connection
+            cur = ggml_add(ctx0, cur, residual);
+            cb(cur, "ffn_residual", il);
+
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        // final norm
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);
+
+        // Explicitly mark as output tensor to ensure proper backend assignment
+        ggml_set_output(cur);
+
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+private:
+    ggml_tensor * build_plamo2_attn_layer(
+            llm_graph_input_attn_kv_unified * inp,
+            ggml_tensor * inp_pos,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            const llama_model & model,
+            int il) {
+
+        // self-attention
+        {
+            // PLaMo-2 uses combined QKV tensor
+            ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
+            cb(qkv, "qkv", il);
+
+            // split QKV tensor into Q, K, V
+            const int64_t n_embd_head_q = hparams.n_embd_head_k;
+            const int64_t n_embd_head_k = hparams.n_embd_head_k;
+            const int64_t n_embd_head_v = hparams.n_embd_head_v;
+            int32_t n_head_kv = hparams.n_head_kv(il);
+
+            const int64_t q_offset = 0;
+            const int64_t k_offset = n_embd_head_q * n_head;
+            const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;
+
+            ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
+            ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv));
+            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv)));
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens);
+
+            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+            cb(Qcur, "Qcur_normed", il);
+
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+            cb(Kcur, "Kcur_normed", il);
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cur = build_attn(inp, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f, il);
+        }
+
+        cb(cur, "attn_out", il);
+
+        return cur;
+    }
+
+    ggml_tensor * build_plamo2_mamba_layer(
+         llm_graph_input_rs * inp,
+               ggml_cgraph * gf,
+               ggml_tensor * cur,
+         const llama_model & model,
+        const llama_ubatch & ubatch,
+                       int   il) {
+
+        const auto * mctx_cur = inp->mctx;
+
+        const auto kv_head = mctx_cur->get_head();
+
+        const int64_t d_conv   = hparams.ssm_d_conv;
+        const int64_t d_inner  = hparams.ssm_d_inner;
+        const int64_t d_state  = hparams.ssm_d_state;
+        const int64_t n_heads  = hparams.ssm_dt_rank;
+        const int64_t head_dim = d_inner / n_heads;
+        const int64_t n_group  = hparams.ssm_n_group;
+        const int64_t n_seqs   = ubatch.n_seqs;
+
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+        ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+        ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
+        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+        // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
+        ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
+        cb(zx, "mamba_in_proj", il);
+        // {8192, 5, 1, 1} -> {8192, 1, 5, 1}
+        zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
+        zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
+        cb(zx, "mamba_in_proj_out", il);
+
+        // split into z and x
+        // => {head_dim * n_heads, n_seq_tokens, n_seqs}
+        ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx));
+        x = ggml_cont(ctx0, x);
+        x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs);
+        // x = ggml_permute(ctx0, x, 0, 2, 1, 3);
+        cb(x, "mamba_x_split", il);
+
+        ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0);
+        cb(z, "mamba_z_split", il);
+
+        // conv1d
+        {
+            // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
+            x = ggml_view_2d(ctx0, x, d_inner, n_seq_tokens * n_seqs, d_inner * x->nb[0], 0);
+            ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
+            cb(conv_x, "mamba_conv1d_input", il);
+
+            // copy last (d_conv - 1) columns back into the state cache
+            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
+                    conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0, last_conv,
+                    ggml_view_1d(ctx0, conv_states_all,
+                        (d_conv - 1)*(d_inner)*(n_seqs),
+                        kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+
+            // 1D convolution
+            x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
+            cb(x, "mamba_conv1d", il);
+
+            x = ggml_silu(ctx0, x);
+            cb(x, "mamba_conv1d_silu", il);
+        }
+
+        // SSM
+        {
+            // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
+            ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x);
+            cb(x_bcdt, "mamba_bcdt_proj", il);
+
+            // split into dt, B, C
+            const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
+            ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
+            ggml_tensor * C  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state);
+            ggml_tensor * dt  = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state));
+            cb(B, "mamba_B_raw", il);
+            cb(C, "mamba_C_raw", il);
+            cb(dt, "mamba_dt_raw", il);
+
+            // Apply RMS norm to dt, B, C (PLaMo-2 specific)
+            B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il);
+            C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il);
+            dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il);
+            cb(B, "mamba_B_normed", il);
+            cb(C, "mamba_C_normed", il);
+            cb(dt, "mamba_dt_normed", il);
+
+            // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
+            dt = build_lora_mm(model.layers[il].ssm_dt, dt);
+            dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
+            cb(dt, "mamba_dt_proj", il);
+
+            ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads);
+            cb(A, "mamba_A", il);
+
+            x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
+            B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0);
+            C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0);
+
+            // use the states and the indices provided by build_recurrent_state
+            // (this is necessary in order to properly use the states before they are overwritten,
+            //  while avoiding to make unnecessary copies of the states)
+            auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+                ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size());
+
+                // Custom operator to optimize the parallel associative scan
+                // as described in the Annex D of the Mamba paper.
+                // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+                return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+            };
+
+            ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
+            cb(y_ssm, "mamba_ssm_scan", il);
+
+            // store last states
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0,
+                    ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
+                    ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs,
+                            kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+
+            ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
+            cb(y, "mamba_y_view", il);
+
+            // Add D parameter and apply gating with z
+            // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
+            ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads);
+            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D));
+            cb(y, "mamba_y_add_d", il);
+
+            y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
+            cb(y, "mamba_y_swiglu_z", il);
+
+            // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+            y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0);
+            cur = build_lora_mm(model.layers[il].ssm_out, y);
+            cb(cur, "mamba_out_proj", il);
+        }
+
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+        cb(cur, "mamba_out", il);
+
+        return cur;
+    }
+};
+
 struct llm_build_arcee : public llm_graph_context {
     llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -16262,6 +16671,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_plamo>(*this, params, gf);
             } break;
+        case LLM_ARCH_PLAMO2:
+            {
+                llm = std::make_unique<llm_build_plamo2>(*this, params, gf);
+            } break;
         case LLM_ARCH_GPT2:
             {
                 llm = std::make_unique<llm_build_gpt2>(*this, params, gf);
@@ -16651,6 +17064,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_PHI3:
         case LLM_ARCH_PHIMOE:
         case LLM_ARCH_PLAMO:
+        case LLM_ARCH_PLAMO2:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_GEMMA3:
index e0e578d6394d822a4c927b241ff71d16c7140aa2..55e6813c248444881bc32dc94f1c214885f5847d 100644 (file)
@@ -11,6 +11,7 @@
 #include <cassert>
 #include <cctype>
 #include <cfloat>
+#include <cmath>
 #include <cstdarg>
 #include <cstring>
 #include <forward_list>
@@ -1196,6 +1197,284 @@ private:
     const llm_tokenizer_rwkv & tokenizer;
 };
 
+struct llm_tokenizer_plamo2 : llm_tokenizer {
+    llm_tokenizer_plamo2(const llama_vocab & vocab) {
+        build(vocab);
+    }
+
+    void build(const llama_vocab & vocab) {
+        // Reset internal structures
+        tokens_.clear();
+        bytes_.assign(256, 0);
+        to_suffix_id_.clear();
+        table_.clear();
+
+        // Build token list and byte mapping
+        std::unordered_map<std::string, float> suffix_to_score;
+        std::unordered_map<std::string, llama_token> token_to_id;
+
+        for (size_t token_id = 0; token_id < vocab.n_tokens(); ++token_id) {
+            const auto & entry = vocab.get_token_data(token_id);
+            tokens_.push_back(entry.text);
+            token_to_id[entry.text] = static_cast<llama_token>(token_id);
+
+            // Handle byte tokens
+            if (vocab.is_byte(token_id)) {
+                if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') {
+                    std::string hex_str = entry.text.substr(3, 2);
+                    int byte_val = std::stoi(hex_str, nullptr, 16);
+                    bytes_[byte_val] = static_cast<llama_token>(token_id);
+                }
+                continue;
+            }
+
+            // Add token and all its suffixes to suffix_to_score
+            suffix_to_score[entry.text] = entry.score;
+
+            // Extract suffixes character by character (UTF-8 aware)
+            std::vector<uint32_t> cpts = unicode_cpts_from_utf8(entry.text);
+            for (size_t i = 1; i < cpts.size(); ++i) {
+                std::string suffix;
+                for (size_t j = i; j < cpts.size(); ++j) {
+                    suffix += unicode_cpt_to_utf8(cpts[j]);
+                }
+                if (suffix_to_score.find(suffix) == suffix_to_score.end()) {
+                    suffix_to_score[suffix] = std::numeric_limits<float>::quiet_NaN();
+                }
+            }
+        }
+
+        // Check that all byte tokens are set
+        for (int i = 0; i < 256; ++i) {
+            if (bytes_[i] == 0) {
+                throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set");
+            }
+        }
+
+        // Build suffix list in lexicographical order of reversed strings
+        std::vector<std::string> suffixes;
+        for (const auto & pair : suffix_to_score) {
+            suffixes.push_back(pair.first);
+        }
+        suffixes.push_back("");  // Empty suffix
+
+        std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) {
+            std::string rev_a(a.rbegin(), a.rend());
+            std::string rev_b(b.rbegin(), b.rend());
+            return rev_a < rev_b;
+        });
+
+        // Build suffix_to_id and to_suffix_id_
+        std::unordered_map<std::string, int32_t> suffix_to_id;
+        int32_t num_pieces = 0;
+
+        for (const auto & suffix : suffixes) {
+            suffix_to_id[suffix] = num_pieces;
+            if (!suffix.empty()) {
+                std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
+
+                std::string remaining;
+                for (size_t i = 1; i < cpts.size(); ++i) {
+                    remaining += unicode_cpt_to_utf8(cpts[i]);
+                }
+
+                int64_t piece_code = (static_cast<int64_t>(cpts[0]) << 32) | suffix_to_id[remaining];
+                to_suffix_id_[piece_code] = num_pieces;
+
+                // Count number of pieces for this suffix
+                int32_t pieces_for_suffix = 1; // sentinel row
+                for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
+                    std::string piece;
+                    for (int32_t i = 0; i < piece_length; ++i) {
+                        piece += unicode_cpt_to_utf8(cpts[i]);
+                    }
+                    if (suffix_to_score.find(piece) != suffix_to_score.end()) {
+                        pieces_for_suffix++;
+                    }
+                }
+                num_pieces += pieces_for_suffix;
+            } else {
+                num_pieces++;  // Empty suffix contributes one piece (sentinel row)
+            }
+        }
+
+        // Build flattened table
+        table_.resize(num_pieces, std::vector<int32_t>(4, 0));
+        int32_t table_idx = 0;
+
+        for (const auto & suffix : suffixes) {
+            // Add all prefixes of the suffix to the table (in decreasing order of length)
+            std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
+            for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
+                std::string piece;
+                for (int32_t i = 0; i < piece_length; ++i) {
+                    piece += unicode_cpt_to_utf8(cpts[i]);
+                }
+
+                auto score_it = suffix_to_score.find(piece);
+                if (score_it == suffix_to_score.end()) {
+                    continue;
+                }
+
+                table_[table_idx][TABLE_PIECE_LENGTH] = piece_length;
+                auto token_it = token_to_id.find(piece);
+                table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1;
+
+                float score = score_it->second;
+                table_[table_idx][TABLE_SCORE] = std::isfinite(score) ?
+                    static_cast<int32_t>(std::round(score * 1e4)) : INVALID_SCORE;
+                table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece];
+
+                table_idx++;
+            }
+
+            // Add sentinel row
+            table_[table_idx][TABLE_PIECE_LENGTH] = 1;
+            table_[table_idx][TABLE_TOKEN_ID] = -1;
+            table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE;
+            table_idx++;
+        }
+    }
+
+    std::vector<llama_token> encode(const std::string & text) const {
+        std::vector<uint32_t> unicode_data = unicode_cpts_from_utf8(text);
+        // Skip the first code point if it is a BOM (Byte Order Mark)
+        if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) {
+            unicode_data.erase(unicode_data.begin());
+        }
+
+        if (unicode_data.empty()) {
+            return {};
+        }
+
+        const size_t data_len = unicode_data.size();
+
+        // Initialize scores array (dynamic programming)
+        std::vector<int64_t> scores(data_len + 1, static_cast<int64_t>(1) << 60);
+        scores[data_len] = 0;
+
+        // Path array to track best tokenization
+        std::vector<std::vector<int32_t>> path(data_len + 1, std::vector<int32_t>(3, 0));
+
+        int32_t suffix_id = 0;
+
+        // Process from end to beginning
+        for (int i = static_cast<int>(data_len) - 1; i >= 0; --i) {
+            uint32_t c = unicode_data[i];
+
+            // Find next suffix ID
+            for (size_t p = suffix_id; p < table_.size(); ++p) {
+                int64_t piece_code = (static_cast<int64_t>(c) << 32) | table_[p][TABLE_PIECE_ID];
+                auto it = to_suffix_id_.find(piece_code);
+                suffix_id = (it != to_suffix_id_.end()) ? it->second : 0;
+
+                if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) {
+                    break;
+                }
+            }
+
+            // Update best path
+            for (size_t p = suffix_id; p < table_.size(); ++p) {
+                int32_t score = table_[p][TABLE_SCORE];
+                if (score > INVALID_SCORE) {
+                    int32_t piece_length = table_[p][TABLE_PIECE_LENGTH];
+                    int64_t s = scores[i + piece_length] - score;
+
+                    if (s < scores[i]) {
+                        scores[i] = s;
+                        path[i][PATH_TOKEN_LENGTH] = piece_length;
+                        path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID];
+                        path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1;
+
+                        if (score == UNKNOWN_SCORE) {
+                            // Add UTF-8 byte count
+                            path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
+                        }
+                    }
+                }
+
+                if (score == UNKNOWN_SCORE) {
+                    break;
+                }
+            }
+        }
+
+        // Decode the best path
+        std::vector<llama_token> token_ids;
+        token_ids.reserve(path[0][PATH_NUM_TOKENS]);
+
+        int pos = 0;
+        while (pos < static_cast<int>(data_len)) {
+            if (path[pos][PATH_TOKEN_ID] >= 0) {
+                token_ids.push_back(path[pos][PATH_TOKEN_ID]);
+            } else {
+                // Fall back to byte tokens
+                uint32_t c = unicode_data[pos];
+                int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
+
+                for (int i = 0; i < s; ++i) {
+                    uint8_t b;
+                    if (s == 1) {
+                        b = c;
+                    } else {
+                        if (i == 0) {
+                            b = (0xF00 >> s) & 0xFF;
+                        } else {
+                            b = 0x80;
+                        }
+                    }
+                    token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]);
+                }
+            }
+
+            assert(path[pos][PATH_TOKEN_LENGTH] > 0);
+            pos += path[pos][PATH_TOKEN_LENGTH];
+        }
+
+        return token_ids;
+    }
+private:
+    // Constants for table structure
+    static constexpr int32_t TABLE_PIECE_LENGTH = 0;
+    static constexpr int32_t TABLE_TOKEN_ID     = 1;
+    static constexpr int32_t TABLE_SCORE        = 2;
+    static constexpr int32_t TABLE_PIECE_ID     = 3;
+
+    // Constants for path array
+    static constexpr int32_t PATH_TOKEN_LENGTH  = 0;
+    static constexpr int32_t PATH_TOKEN_ID      = 1;
+    static constexpr int32_t PATH_NUM_TOKENS    = 2;
+
+    // Score constants
+    static constexpr int32_t INVALID_SCORE = -20000000;
+    static constexpr int32_t UNKNOWN_SCORE = -10000000;
+
+    // List of tokens in the vocabulary
+    std::vector<std::string> tokens_;
+
+    // Mapping from byte code point to token ID (for byte fallback)
+    std::vector<llama_token> bytes_;
+
+    // Mapping from piece code to suffix ID
+    std::unordered_map<int64_t, int32_t> to_suffix_id_;
+
+    // Flattened table representing the Trie structure
+    // Each row contains: [piece_length, token_id, score, piece_id]
+    std::vector<std::vector<int32_t>> table_;
+};
+
+struct llm_tokenizer_plamo2_session {
+    llm_tokenizer_plamo2_session(const llm_tokenizer_plamo2 & tokenizer) : tokenizer(tokenizer) {}
+
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
+        std::vector<llama_token> tokens = tokenizer.encode(text);
+        output.insert(output.end(), tokens.begin(), tokens.end());
+    }
+
+private:
+    const llm_tokenizer_plamo2 & tokenizer;
+};
+
 //
 // impl
 //
@@ -1499,6 +1778,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             special_unk_id = LLAMA_TOKEN_NULL;
             special_sep_id = LLAMA_TOKEN_NULL;
             special_pad_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "plamo2") {
+            type = LLAMA_VOCAB_TYPE_PLAMO2;
+
+            // PLaMo-2 default special tokens (these will be overridden by model config)
+            special_bos_id = 1;  // <|plamo:bos|>
+            special_eos_id = 2;  // <|plamo:eos|>
+            special_unk_id = 0;  // <|plamo:unk|>
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = 3;  // <|plamo:pad|>
+            special_mask_id = LLAMA_TOKEN_NULL;
         } else {
             throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
         }
@@ -2145,13 +2434,14 @@ enum llama_vocab_type llama_vocab::impl::get_type() const {
 
 std::string llama_vocab::impl::type_name() const{
     switch (type) {
-        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
-        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
-        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
-        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
-        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
-        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
-        default:                    return "unknown";
+        case LLAMA_VOCAB_TYPE_NONE:   return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:    return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:    return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:    return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:    return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV:   return "RWKV";
+        case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2";
+        default:                      return "unknown";
     }
 }
 
@@ -2234,6 +2524,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
         case LLAMA_VOCAB_TYPE_RWKV:
             tokenizer = std::make_unique<llm_tokenizer_rwkv>(vocab);
             break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            tokenizer = std::make_unique<llm_tokenizer_plamo2>(vocab);
+            break;
         default:
             GGML_ABORT("unsupported vocab type");
     }
@@ -2566,6 +2859,23 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            {
+                llm_tokenizer_plamo2_session session(*static_cast<const llm_tokenizer_plamo2 *>(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
 #endif
@@ -2664,6 +2974,24 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
                 memcpy(buf, result.data(), result.size());
                 return (int)result.size();
             }
+            case LLAMA_VOCAB_TYPE_PLAMO2: {
+                // PLaMo-2 uses similar token handling as BPE/SPM
+                if (vocab.is_byte(token)) {
+                    // Handle byte tokens like <0xXX>
+                    if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') {
+                        int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16);
+                        if (length < 1) {
+                            return -1;
+                        }
+                        buf[0] = static_cast<char>(hex_val);
+                        return 1;
+                    }
+                }
+
+                // Normal token - just copy the text
+                std::string result = token_text;
+                return _try_copy(result.data(), result.size());
+            }
             default:
                 GGML_ABORT("fatal error");
         }
@@ -2908,6 +3236,12 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const {
         case LLAMA_VOCAB_TYPE_BPE: {
             return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
         }
+        case LLAMA_VOCAB_TYPE_PLAMO2: {
+            // PLaMo-2 uses byte tokens in format <0xXX>
+            char hex_str[8];
+            snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch);
+            return pimpl->token_to_id.at(hex_str);
+        }
         default:
             GGML_ABORT("fatal error");
     }
@@ -3385,4 +3719,3 @@ int32_t llama_detokenize(
                         bool   unparse_special) {
     return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
-