model : add grok-2 support (#15539)
author Sigbjørn Skjæret <redacted>
Sun, 14 Sep 2025 21:00:59 +0000 (23:00 +0200)
committer GitHub <redacted>
Sun, 14 Sep 2025 21:00:59 +0000 (23:00 +0200)
* add grok-2 support

* type fix

* type fix

* type fix

* "fix" vocab for invalid sequences

* fix expert tensor mapping and spaces in vocab

* add chat template

* fix norm tensor mapping

* rename layer_out_norm to ffn_post_norm

* ensure ffn_post_norm is mapped

* fix experts merging

* remove erroneous FFN_GATE entry

* concatenate split tensors and add more metadata

* process all expert layers and try cat instead of hstack

* add support for community BPE vocab

* fix expert feed forward length and ffn_down concat

* commit this too

* add ffn_up/gate/down, unsure if sequence is right

* add ffn_gate/down/up to tensor names

* correct residual moe (still not working)

* mess--

* fix embedding scale being applied twice

* add built in chat template

* change beta fast for grok if default value

* remove spm vocab in favor of community bpe vocab

* change attention temp length metadata type to integer

* update attention temp length metadata

* remove comment

* replace M_SQRT2 with std::sqrt(2)

* add yarn metadata, move defaults to hparams

16 files changed:
common/common.h
convert_hf_to_gguf.py
convert_hf_to_gguf_update.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-chat.cpp
src/llama-chat.h
src/llama-context.cpp
src/llama-graph.cpp
src/llama-hparams.h
src/llama-model.cpp
src/llama-vocab.cpp
src/llama-vocab.h

diff --git a/common/common.h b/common/common.h
index cf57d48415bd1ea9d51b8dfa0b218ec0224438d0..5063d73f9636946f2bf398d159fbd7004a14b508 100644 (file)
@@ -288,9 +288,9 @@ struct common_params {
     float   rope_freq_base        =  0.0f; // RoPE base frequency
     float   rope_freq_scale       =  0.0f; // RoPE frequency scaling factor
     float   yarn_ext_factor       = -1.0f; // YaRN extrapolation mix factor
-    float   yarn_attn_factor      =  1.0f; // YaRN magnitude scaling factor
-    float   yarn_beta_fast        = 32.0f; // YaRN low correction dim
-    float   yarn_beta_slow        =  1.0f; // YaRN high correction dim
+    float   yarn_attn_factor      = -1.0f; // YaRN magnitude scaling factor
+    float   yarn_beta_fast        = -1.0f; // YaRN low correction dim
+    float   yarn_beta_slow        = -1.0f; // YaRN high correction dim
     int32_t yarn_orig_ctx         =     0; // YaRN original context length
 
     // offload params
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index bbc21813f81ca171c9b01b509f200dc28e3cf575..855789f1ba1fad38a0ecde7605fa71dd6af6d328 100755 (executable)
@@ -735,6 +735,9 @@ class TextModel(ModelBase):
         if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
             # ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
             res = "qwen2"
+        if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
+            # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
+            res = "grok-2"
         if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
             # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
             res = "llama-bpe"
@@ -2682,12 +2685,20 @@ class BitnetModel(TextModel):
         yield (new_name, data_torch)
 
 
-@ModelBase.register("GrokForCausalLM")
+@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM")
 class GrokModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROK
 
     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / 'tokenizer.model').is_file():
+            self._set_vocab_sentencepiece()
+            return
+
+        if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
+            logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
+            sys.exit(1)
+
+        self._set_vocab_gpt2()
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -2695,11 +2706,46 @@ class GrokModel(TextModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
-    _experts: list[dict[str, Tensor]] | None = None
+        self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
+        self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
+        if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+
+        # Treat "original" as "yarn", seems to have been a mistake
+        if self.hparams.get("rope_type") in ("yarn", "original"):
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])
+
+        if temp_len := self.hparams.get("attn_temperature_len"):
+            self.gguf_writer.add_attn_temperature_length(temp_len)
+
+        self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
+        self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
+        self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
+
+    _experts: list[dict[str, list[Tensor]]] | None = None
+    _cur_expert = ""
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors: list[tuple[str, Tensor]] = []
+        is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
+
+        if not is_expert:
+            tensors.append((self.map_tensor_name(name), data_torch))
+
         # process the experts separately
-        if name.find(".moe.") != -1:
+        if is_expert or self._cur_expert:
             n_experts = self.hparams["num_local_experts"]
 
             assert bid is not None
@@ -2707,32 +2753,41 @@ class GrokModel(TextModel):
             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]
 
-            self._experts[bid][name] = data_torch
+            # concatenate split tensors
+            if name in self._experts[bid]:
+                self._cur_expert = name
+                self._experts[bid][name].append(data_torch)
+                return []
+            elif is_expert:
+                self._cur_expert = name
+                self._experts[bid][name] = [data_torch]
+                return []
+            else:
+                self._cur_expert = ""
 
-            if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []
+            for bid in range(self.block_count):
+                if len(self._experts[bid]) >= n_experts * 3:
+                    # merge the experts into a single 3d tensor
+                    for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
+                        datas: list[Tensor] = []
 
-                # merge the experts into a single 3d tensor
-                for wid in ["linear", "linear_1", "linear_v"]:
-                    datas: list[Tensor] = []
+                        for xid in range(n_experts):
+                            ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
+                            if ename not in self._experts[bid]:
+                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
+                            tensor_list = self._experts[bid][ename]
+                            datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
+                            del self._experts[bid][ename]
 
-                    for xid in range(n_experts):
-                        ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
-                        datas.append(self._experts[bid][ename])
-                        del self._experts[bid][ename]
+                        data_torch = torch.stack(datas, dim=0)
 
-                    data_torch = torch.stack(datas, dim=0)
+                        merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"
 
-                    merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
-
-                    new_name = self.map_tensor_name(merged_name)
+                        new_name = self.map_tensor_name(merged_name)
 
-                    tensors.append((new_name, data_torch))
-                return tensors
-            else:
-                return []
+                        yield (new_name, data_torch)
 
-        return [(self.map_tensor_name(name), data_torch)]
+        yield from tensors
 
 
 @ModelBase.register("DbrxForCausalLM")
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index 575e05e193c2ea0a5cb6712cf429cda5ebab76c6..eb8fdfa7e1014af2017c5eb0c5c455c81d07d013 100755 (executable)
@@ -158,6 +158,7 @@ pre_computed_hashes = [
     {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
     {"name": "kimi-k2",   "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base",   "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
     {"name": "qwen2",     "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
+    {"name": "grok-2",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
 ]
 
 
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 1e88b6505bae06be2f592160733ab38b967c1eb0..c7edef919b08e378150a3dea8903a22ebd901fc7 100644 (file)
@@ -111,6 +111,7 @@ class Keys:
         DECODER_START_TOKEN_ID            = "{arch}.decoder_start_token_id"
         DECODER_BLOCK_COUNT               = "{arch}.decoder_block_count"
         ATTN_LOGIT_SOFTCAPPING            = "{arch}.attn_logit_softcapping"
+        ROUTER_LOGIT_SOFTCAPPING          = "{arch}.router_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING           = "{arch}.final_logit_softcapping"
         SWIN_NORM                         = "{arch}.swin_norm"
         RESCALE_EVERY_N_LAYERS            = "{arch}.rescale_every_n_layers"
@@ -146,21 +147,27 @@ class Keys:
         REL_BUCKETS_COUNT            = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW               = "{arch}.attention.sliding_window"
         SCALE                        = "{arch}.attention.scale"
+        OUTPUT_SCALE                 = "{arch}.attention.output_scale"
+        TEMPERATURE_LENGTH           = "{arch}.attention.temperature_length"
         KEY_LENGTH_MLA               = "{arch}.attention.key_length_mla"
         VALUE_LENGTH_MLA             = "{arch}.attention.value_length_mla"
         SHARED_KV_LAYERS             = "{arch}.attention.shared_kv_layers"
         SLIDING_WINDOW_PATTERN       = "{arch}.attention.sliding_window_pattern"
 
     class Rope:
-        DIMENSION_COUNT         = "{arch}.rope.dimension_count"
-        DIMENSION_SECTIONS      = "{arch}.rope.dimension_sections"
-        FREQ_BASE               = "{arch}.rope.freq_base"
-        SCALING_TYPE            = "{arch}.rope.scaling.type"
-        SCALING_FACTOR          = "{arch}.rope.scaling.factor"
-        SCALING_ATTN_FACTOR     = "{arch}.rope.scaling.attn_factor"
-        SCALING_ORIG_CTX_LEN    = "{arch}.rope.scaling.original_context_length"
-        SCALING_FINETUNED       = "{arch}.rope.scaling.finetuned"
-        SCALING_YARN_LOG_MUL    = "{arch}.rope.scaling.yarn_log_multiplier"
+        DIMENSION_COUNT          = "{arch}.rope.dimension_count"
+        DIMENSION_SECTIONS       = "{arch}.rope.dimension_sections"
+        FREQ_BASE                = "{arch}.rope.freq_base"
+        SCALING_TYPE             = "{arch}.rope.scaling.type"
+        SCALING_FACTOR           = "{arch}.rope.scaling.factor"
+        SCALING_ATTN_FACTOR      = "{arch}.rope.scaling.attn_factor"
+        SCALING_ORIG_CTX_LEN     = "{arch}.rope.scaling.original_context_length"
+        SCALING_FINETUNED        = "{arch}.rope.scaling.finetuned"
+        SCALING_YARN_LOG_MUL     = "{arch}.rope.scaling.yarn_log_multiplier"
+        SCALING_YARN_EXT_FACTOR  = "{arch}.rope.scaling.yarn_ext_factor"
+        SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
+        SCALING_YARN_BETA_FAST   = "{arch}.rope.scaling.yarn_beta_fast"
+        SCALING_YARN_BETA_SLOW   = "{arch}.rope.scaling.yarn_beta_slow"
 
     class Split:
         LLM_KV_SPLIT_NO            = "split.no"
@@ -1114,6 +1121,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_POST_NORM,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
     MODEL_ARCH.GPTNEOX: [
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 7ff12f7f5709ddcdb8cc1b616e1384e22386a53c..d925fca7e3e110e43e28a39e7427a0ebc62d06fa 100644 (file)
@@ -733,6 +733,9 @@ class GGUFWriter:
     def add_attn_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
 
+    def add_router_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_final_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
 
@@ -829,6 +832,12 @@ class GGUFWriter:
     def add_attention_scale(self, value: float) -> None:
         self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
 
+    def add_attn_output_scale(self, value: float) -> None:
+        self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
+
+    def add_attn_temperature_length(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
 
@@ -859,6 +868,18 @@ class GGUFWriter:
     def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
         self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
 
+    def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
+
     def add_ssm_conv_kernel(self, value: int) -> None:
         self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
 
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index b0c3d65e95847d9aa87250d82c2e19e1dd9d3bd6..8fd9e454e0e814de27364b4d7383538b4e243490 100644 (file)
@@ -136,6 +136,7 @@ class TensorNameMap:
             "model.layers.{bid}.norm",                              # mamba-qbert
             "backbone.layers.{bid}.norm",                           # mamba
             "transformer.decoder_layer.{bid}.rms_norm",             # Grok
+            "model.layers.{bid}.pre_attn_norm",                     # grok-2
             "transformer.blocks.{bid}.norm_attn_norm.norm_1",       # dbrx
             "encoder.layers.{bid}.input_layernorm",                 # chatglm
             "transformer.layers.{bid}.attn_norm",                   # openelm
@@ -278,6 +279,7 @@ class TensorNameMap:
             "transformer.layer.{bid}.sa_layer_norm",           # distillbert
             "encoder.layers.{bid}.norm1",                      # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_1",      # Grok
+            "model.layers.{bid}.post_attn_norm",               # grok-2
             "transformer.blocks.{bid}.norm_attn_norm.norm_2",  # dbrx
         ),
 
@@ -313,6 +315,7 @@ class TensorNameMap:
             "h.{bid}.ln_2",                                                  # gpt2
             "model.layers.{bid}.ffn_norm",                                   # internlm2
             "transformer.decoder_layer.{bid}.rms_norm_2",                    # Grok
+            "model.layers.{bid}.pre_moe_norm",                               # grok-2
             "encoder.layers.{bid}.post_attention_layernorm",                 # chatglm
             "transformer.layers.{bid}.ffn_norm",                             # openelm
             "model.layers.{bid}.pre_ff_layernorm",                           # jamba granite-hybrid
@@ -333,11 +336,12 @@ class TensorNameMap:
 
         # Post feed-forward norm
         MODEL_TENSOR.FFN_POST_NORM: (
-            "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
-            "layers.{bid}.post_feedforward_layernorm",       # embeddinggemma
-            "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
+            "model.layers.{bid}.post_feedforward_layernorm",  # gemma2 olmo2
+            "layers.{bid}.post_feedforward_layernorm",        # embeddinggemma
+            "model.layers.{bid}.post_mlp_layernorm",          # glm-4-0414
             "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
             "model.layers.{bid}.feed_forward.up_proj",
+            "model.layers.{bid}.post_moe_norm",               # grok-2
         ),
 
         MODEL_TENSOR.FFN_GATE_INP: (
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 81f9746818d4aa520f136f7c586f76011b7762ca..3122331d8ed772aaf62f5116f0a81a881963a21e 100644 (file)
@@ -139,6 +139,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
     { LLM_KV_DECODER_BLOCK_COUNT,               "%s.decoder_block_count"               },
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING,            "%s.attn_logit_softcapping"            },
+    { LLM_KV_ROUTER_LOGIT_SOFTCAPPING,          "%s.router_logit_softcapping"          },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING,           "%s.final_logit_softcapping"           },
     { LLM_KV_SWIN_NORM,                         "%s.swin_norm"                         },
     { LLM_KV_RESCALE_EVERY_N_LAYERS,            "%s.rescale_every_n_layers"            },
@@ -169,19 +170,25 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,       "%s.attention.relative_buckets_count"       },
     { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
+    { LLM_KV_ATTENTION_OUTPUT_SCALE,                 "%s.attention.output_scale"                 },
+    { LLM_KV_ATTENTION_TEMPERATURE_LENGTH,           "%s.attention.temperature_length"           },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
 
-    { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
-    { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
-    { LLM_KV_ROPE_FREQ_BASE,            "%s.rope.freq_base"                       },
-    { LLM_KV_ROPE_SCALE_LINEAR,         "%s.rope.scale_linear"                    },
-    { LLM_KV_ROPE_SCALING_TYPE,         "%s.rope.scaling.type"                    },
-    { LLM_KV_ROPE_SCALING_FACTOR,       "%s.rope.scaling.factor"                  },
-    { LLM_KV_ROPE_SCALING_ATTN_FACTOR,  "%s.rope.scaling.attn_factor"             },
-    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
-    { LLM_KV_ROPE_SCALING_FINETUNED,    "%s.rope.scaling.finetuned"               },
-    { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier"     },
+    { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 },
+    { LLM_KV_ROPE_DIMENSION_SECTIONS,       "%s.rope.dimension_sections"              },
+    { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       },
+    { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"                    },
+    { LLM_KV_ROPE_SCALING_TYPE,             "%s.rope.scaling.type"                    },
+    { LLM_KV_ROPE_SCALING_FACTOR,           "%s.rope.scaling.factor"                  },
+    { LLM_KV_ROPE_SCALING_ATTN_FACTOR,      "%s.rope.scaling.attn_factor"             },
+    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,     "%s.rope.scaling.original_context_length" },
+    { LLM_KV_ROPE_SCALING_FINETUNED,        "%s.rope.scaling.finetuned"               },
+    { LLM_KV_ROPE_SCALING_YARN_LOG_MUL,     "%s.rope.scaling.yarn_log_multiplier"     },
+    { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,  "%s.rope.scaling.yarn_ext_factor"         },
+    { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor"        },
+    { LLM_KV_ROPE_SCALING_YARN_BETA_FAST,   "%s.rope.scaling.yarn_beta_fast"          },
+    { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,   "%s.rope.scaling.yarn_beta_slow"          },
 
     { LLM_KV_SPLIT_NO,            "split.no"            },
     { LLM_KV_SPLIT_COUNT,         "split.count"         },
@@ -398,12 +405,16 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
             { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
             { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
             { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
             { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
             { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
             { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
             { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
             { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
             { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
         },
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 6ee3707dcfbf64deea1c2d112803d24b2bc3eda7..a4ac28b5252f9337d1122c8d14d907b50f7a55f3 100644 (file)
@@ -143,6 +143,7 @@ enum llm_kv {
     LLM_KV_DECODER_START_TOKEN_ID,
     LLM_KV_DECODER_BLOCK_COUNT,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_ROUTER_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
     LLM_KV_SWIN_NORM,
     LLM_KV_RESCALE_EVERY_N_LAYERS,
@@ -173,6 +174,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_OUTPUT_SCALE,
+    LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
@@ -186,6 +189,10 @@ enum llm_kv {
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,
     LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
+    LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_BETA_FAST,
+    LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,
 
     LLM_KV_SPLIT_NO,
     LLM_KV_SPLIT_COUNT,
diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp
index 9d8e57eac1f69b04d960b9fa07334c9130c83ab7..66e6c6a38f1cd4ac29080aa41ea8fdd56dc4e494 100644 (file)
@@ -70,6 +70,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "hunyuan-dense",     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE     },
     { "kimi-k2",           LLM_CHAT_TEMPLATE_KIMI_K2           },
     { "seed_oss",          LLM_CHAT_TEMPLATE_SEED_OSS          },
+    { "grok-2",            LLM_CHAT_TEMPLATE_GROK_2            },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -204,6 +205,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_KIMI_K2;
     } else if (tmpl_contains("<seed:bos>")) {
         return LLM_CHAT_TEMPLATE_SEED_OSS;
+    } else if (tmpl_contains("'Assistant: '  + message['content'] + '<|separator|>")) {
+        return LLM_CHAT_TEMPLATE_GROK_2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -763,6 +766,20 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<seed:bos>assistant\n";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "System: " << trim(message->content) << "<|separator|>\n\n";
+            } else if (role == "user") {
+                ss << "Human: " << trim(message->content) << "<|separator|>\n\n";
+            } else if (role == "assistant") {
+                ss << "Assistant: " << message->content << "<|separator|>\n\n";
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
     } else {
         // template not supported
         return -1;
diff --git a/src/llama-chat.h b/src/llama-chat.h
index 21d53ed08b4c3e39146b8d5c728b5bd61502c21c..5a87d9ab627bcccd214c41b4c3260449357c1d7b 100644 (file)
@@ -50,6 +50,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
     LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_SEED_OSS,
+    LLM_CHAT_TEMPLATE_GROK_2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 289a32b6d347319cf21e584945a18a0a5d15ddcb..e6f76421cf1319702d476e751db6e45b3ca498ae 100644 (file)
@@ -35,10 +35,10 @@ llama_context::llama_context(
 
     cparams.n_threads        = params.n_threads;
     cparams.n_threads_batch  = params.n_threads_batch;
-    cparams.yarn_ext_factor  = params.yarn_ext_factor;
-    cparams.yarn_attn_factor = params.yarn_attn_factor;
-    cparams.yarn_beta_fast   = params.yarn_beta_fast;
-    cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.yarn_ext_factor  = params.yarn_ext_factor  >= 0.0f ? params.yarn_ext_factor  : hparams.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
+    cparams.yarn_beta_fast   = params.yarn_beta_fast   >= 0.0f ? params.yarn_beta_fast   : hparams.yarn_beta_fast;
+    cparams.yarn_beta_slow   = params.yarn_beta_slow   >= 0.0f ? params.yarn_beta_slow   : hparams.yarn_beta_slow;
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.no_perf          = params.no_perf;
@@ -2263,9 +2263,9 @@ llama_context_params llama_context_default_params() {
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
-        /*.yarn_attn_factor            =*/ 1.0f,
-        /*.yarn_beta_fast              =*/ 32.0f,
-        /*.yarn_beta_slow              =*/ 1.0f,
+        /*.yarn_attn_factor            =*/ -1.0f,
+        /*.yarn_beta_fast              =*/ -1.0f,
+        /*.yarn_beta_slow              =*/ -1.0f,
         /*.yarn_orig_ctx               =*/ 0,
         /*.defrag_thold                =*/ -1.0f,
         /*.cb_eval                     =*/ nullptr,
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index ddc772b179f7e713968533b0d863de8fb4b543d1..9f2e417f1ff4b19be80e6371ff2048b7bad29c7c 100644 (file)
@@ -1335,14 +1335,14 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         if (arch == LLM_ARCH_GROK) {
             // need to do the following:
-            // multiply by attn_output_multiplyer of 0.08838834764831845
+            // multiply by attn_output_multiplier
             // and then :
             // kq = 30 * tanh(kq / 30)
             // before the softmax below
 
-            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
+            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
             cb(kq, "kq_tanh", il);
-            kq = ggml_scale(ctx0, kq, 30);
+            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
             cb(kq, "kq_scaled", il);
         }
 
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 4dca2ca41d095af5f7c2b66602d105e627f4fe9e..116d728e8c9f27b9991985055d322bb813573f8e 100644 (file)
@@ -82,8 +82,9 @@ struct llama_hparams {
     float f_norm_rms_eps;
     float f_norm_group_eps;
 
-    float f_attn_logit_softcapping  = 50.0f;
-    float f_final_logit_softcapping = 30.0f;
+    float f_attn_logit_softcapping   = 50.0f;
+    float f_router_logit_softcapping = 30.0f;
+    float f_final_logit_softcapping  = 30.0f;
 
     // for RWKV
     uint32_t rescale_every_n_layers = 0;
@@ -104,6 +105,11 @@ struct llama_hparams {
     uint32_t n_ctx_orig_yarn;
     float    rope_yarn_log_mul = 0.0f;
 
+    float    yarn_ext_factor  = -1.0f;
+    float    yarn_attn_factor =  1.0f;
+    float    yarn_beta_fast   = 32.0f;
+    float    yarn_beta_slow   =  1.0f;
+
     std::array<int, 4> rope_sections;
 
     // Sliding Window Attention (SWA)
@@ -136,6 +142,10 @@ struct llama_hparams {
     float f_embedding_scale = 0.0f;
     float f_attention_scale = 0.0f;
 
+    // grok-2
+    float    f_attn_out_scale = 0.0f;
+    uint32_t attn_temp_length = 0;
+
     bool causal_attn   = true;
     bool use_alibi     = false;
     bool attn_soft_cap = false;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 818b209641a5a2c4a7f9da39065c2101d14dadc3..4864ed8e72abdda5f19c5183d7caa4edcc0e54ce 100644 (file)
@@ -685,7 +685,30 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GROK:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                // defaults for old GGUFs
+                hparams.yarn_beta_fast = 8.0f;
+                hparams.f_logit_scale = 0.5773502691896257f;
+                hparams.f_embedding_scale = 78.38367176906169f;
+                hparams.f_attn_out_scale = 0.08838834764831845f;
+                hparams.f_attn_logit_softcapping = 30.0f;
+                hparams.f_router_logit_softcapping = 30.0f;
+                // no final_logit_softcapping in grok-1
+                hparams.f_final_logit_softcapping = 0.0f;
+
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,  hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,   hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_LOGIT_SCALE,                  hparams.f_logit_scale, false);
+                ml.get_key(LLM_KV_EMBEDDING_SCALE,              hparams.f_embedding_scale, false);
+                ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE,       hparams.f_attn_out_scale, false);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,       hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING,     hparams.f_router_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,      hparams.f_final_logit_softcapping, false);
+
+                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH,  hparams.attn_temp_length, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,  hparams.yarn_ext_factor, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST,   hparams.yarn_beta_fast, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,   hparams.yarn_beta_slow, false);
 
                 switch (hparams.n_layer) {
                     case 64: type = LLM_TYPE_314B; break;
@@ -2540,6 +2563,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
 
+                    const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
 
@@ -2554,12 +2578,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff,   n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
+
                         layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd,   n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff_exp, n_expert}, 0);
 
-                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        if (!layer.ffn_post_norm) {
+                            layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                        }
                     }
                 } break;
             case LLM_ARCH_DBRX:
@@ -7028,9 +7059,6 @@ struct llm_build_grok : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        // multiply by embedding_multiplier_scale of 78.38367176906169
-        inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
-
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
@@ -7102,26 +7130,22 @@ struct llm_build_grok : public llm_graph_context {
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
-            // Grok
-            // if attn_out_norm is present then apply it before adding the input
-            if (model.layers[il].attn_out_norm) {
-                cur = build_norm(cur,
-                        model.layers[il].attn_out_norm, NULL,
-                        LLM_NORM_RMS, il);
-                cb(cur, "attn_out_norm", il);
-            }
+            cur = build_norm(cur,
+                    model.layers[il].attn_out_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_out_norm", il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
-            // MoE branch
             cur = build_norm(ffn_inp,
                     model.layers[il].ffn_norm, NULL,
                     LLM_NORM_RMS, il);
             cb(cur, "ffn_norm", il);
 
-            cur = build_moe_ffn(cur,
+            // MoE branch
+            ggml_tensor * moe_out = build_moe_ffn(cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -7132,18 +7156,28 @@ struct llm_build_grok : public llm_graph_context {
                     false, 0.0,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     il);
-            cb(cur, "ffn_moe_out", il);
+            cb(moe_out, "ffn_moe_out", il);
 
-            // Grok
-            // if layer_out_norm is present then apply it before adding the input
-            // Idea: maybe ffn_out_norm is a better name
-            if (model.layers[il].layer_out_norm) {
-                cur = build_norm(cur,
-                        model.layers[il].layer_out_norm, NULL,
-                        LLM_NORM_RMS, il);
-                cb(cur, "layer_out_norm", il);
+            if (model.layers[il].ffn_up) {
+                ggml_tensor * ffn_out = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(ffn_out, "ffn_out", il);
+
+                cur = ggml_scale(ctx0, ggml_add(ctx0, ffn_out, moe_out), std::sqrt(2) / 2);
+                cb(cur, "ffn_out", il);
+            } else {
+                cur = moe_out;
             }
 
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_post_norm", il);
+
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -7166,10 +7200,14 @@ struct llm_build_grok : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
-        // Grok
-        // multiply logits by output_multiplier_scale of 0.5773502691896257
+        cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);
 
-        cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
+        // final logit soft-capping
+        if (hparams.f_final_logit_softcapping) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+            cur = ggml_tanh(ctx0, cur);
+            cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+        }
 
         cb(cur, "result_output", -1);
         res->t_logits = cur;
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index ca02b63a58407951c534e4544b608ddcf7fd1169..b551253afbe19da6f59122dd9b6838b10b3da19e 100644 (file)
@@ -434,6 +434,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_GROK_2:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
@@ -1974,6 +1981,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "kimi-k2") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
                 clean_spaces = false;
+            } else if (
+                tokenizer_pre == "grok-2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
+                clean_spaces = false;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index 61b8124216847b2eb9d84586c8aae8c382a1c589..0d2f28c36c80dd714a6a2d01981872583dbbc7cd 100644 (file)
@@ -47,6 +47,7 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN        = 36,
     LLAMA_VOCAB_PRE_TYPE_KIMI_K2        = 37,
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE  = 38,
+    LLAMA_VOCAB_PRE_TYPE_GROK_2         = 39,
 };
 
 struct LLM_KV;