model : add support for Falcon-H1 family (#14534)
author     ibrahim khadraoui <redacted>
           Wed, 9 Jul 2025 08:03:49 +0000 (12:03 +0400)
committer  GitHub <redacted>
           Wed, 9 Jul 2025 08:03:49 +0000 (10:03 +0200)
* v1

* push more fixes

* another fix

* fix

* more fixes

* minor fix

* more cleaning on python code

* python fixes

* changed precision for multipliers: float32 -> float64

* fixes

* another fix

* fix

* pre-norm -> norm

* fix

* Revert "fix"

This reverts commit 243e4d1a50bd73467d99f6b289b9a1826f83b94b.

* fix

* small fix ffn_norm

* try

* mix instead of max

* fix vocab size

* resolve merge conflicts

* fixed multipliers

* falcon-h1 specific vocab resolved

* read arch from gguf.MODEL_ARCH

* mamba_d_ssm added to d_inner find_hparam

* remove unused functions from gguf_writer.py

* override modify_tensors instead of get_tensors

* fix conversion and d_inner

* added some cb functions for debugging purposes

* inp_out_ids moved outside of layers loop

* mup_vec create as float64

* fix rope_theta

* injected mup

* clean ups

* rm extra space

* rm unused MAMBA_CHUNK_SIZE

* rm unused key

* add bos False

* changed ROPE_TYPE

* cleaning debugging stuff

* cleaning debug quant

* fix comment

* some cleanups

* some cleanups

* Update src/llama-model-loader.cpp

* more cleanups

* more cleanups

* d_ssm -> d_inner;

* cleaning unused hparams

* cleanup

* more cleanups

* more cleanups on python conversion;

* minor cleanups

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <redacted>
* remove todo

* added falcon-h1

* tensor not required

* clean

* remove unneeded attributes

* more cleanups and fixed conversion

* remove final_norm

* flake8 fixes

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* flake8 fixes

* Update src/llama-hparams.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-arch.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* added hashes

* Update src/llama-arch.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Update src/llama-vocab.cpp

Co-authored-by: Georgi Gerganov <redacted>
* update the update file

* Revert "update the update file"

This reverts commit 082ab4ad2a3927384d878666a5f8cae4eb15f577.

* fix: address suggestions

* fix: update convert_hf_to_gguf.py

* Update gguf-py/gguf/constants.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model-loader.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* d_inner fixed

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* reshaping ssm_norm for 34B

* removing generate_mup

* remove duplicate metadata keys

* rm comment

* final comment

* fix unused args

* fix constants

* fix bad merge

* Update src/llama-model.cpp

Co-authored-by: compilade <redacted>
* falcon-h1: remove unused ssm_in_b and bad merge

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* falcon-h1: fix last comment

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <redacted>
* falcon-h1: revert add_add_bos(False)

* falcon-h1: fix tied weights

* falcon-h1: remove whitespace

* falcon-h1: fix wrong size param

* falcon-h1: fix whitespace issues

---------

Co-authored-by: younesbelkada <redacted>
Co-authored-by: Younes B <redacted>
Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: Sigbjørn Skjæret <redacted>
Co-authored-by: compilade <redacted>
convert_hf_to_gguf.py
convert_hf_to_gguf_update.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/llama-vocab.cpp

index 5d12d4799fafcb90230cb2455d50e0d3c48aa7ef..4dedc020b61e158bd768bd5477184addac0de832 100755 (executable)
@@ -818,6 +818,18 @@ class TextModel(ModelBase):
         if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
             # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
             res = "hunyuan"
+        if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
+            res = "falcon-h1"
+        if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
+            res = "falcon-h1"
+        if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
+            res = "falcon-h1"
+        if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
+            res = "falcon-h1"
 
         if res is None:
             logger.warning("\n")
@@ -4899,17 +4911,19 @@ class Mamba2Model(TextModel):
     def set_gguf_parameters(self):
         d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
         d_conv  = self.find_hparam(["conv_kernel",       "d_conv"],  optional=True) or 4
-        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
         d_state = self.find_hparam(["state_size",        "d_state"], optional=True) or 128
-        head_dim = self.find_hparam(["head_dim"],                    optional=True) or 64
+        head_dim = self.find_hparam(["mamba_d_head", "head_dim"],    optional=True) or 64
         n_group = self.find_hparam(["n_groups"],                     optional=True) or 1
 
         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
 
         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
-        assert d_inner == 2 * d_model
-        assert d_inner % head_dim == 0
+        # skip the assertion for FalconH1 Model
+        if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
+            assert d_inner == 2 * d_model
+            assert d_inner % head_dim == 0
 
         self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
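The widened key lists ("mamba_d_ssm", "mamba_d_head") cooperate with the FalconH1Model.find_hparam override added further down, which additionally tries every key with a "mamba_" prefix. A rough sketch of the effective lookup (hypothetical config values, not the real ModelBase code):

    hparams = {"hidden_size": 2048, "mamba_d_ssm": 3072}  # hypothetical Falcon-H1 config

    def find_hparam(keys, optional=False):
        # FalconH1Model extends `keys` with "mamba_"-prefixed variants before lookup
        candidates = list(keys) + ["mamba_" + k for k in keys]
        for k in candidates:
            if k in hparams:
                return hparams[k]
        if optional:
            return None
        raise KeyError(f"none of {keys} found")

    d_model = find_hparam(["hidden_size", "d_model", "dim"])                             # 2048
    d_inner = find_hparam(["d_ssm", "intermediate_size"], optional=True) or 2 * d_model  # 3072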
@@ -4946,7 +4960,7 @@ class Mamba2Model(TextModel):
             data_torch = data_torch.reshape((*data_torch.shape, 1))
         elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
             d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
             n_group = self.hparams.get("n_groups", 1)
             data_torch = data_torch.reshape((n_group, d_inner // n_group))
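As a concrete illustration of that reshape: the checkpoint stores the grouped RMS-norm weight flat, and the GGUF layout wants one scale row per group. With hypothetical sizes:

    import torch

    d_inner, n_group = 3072, 2                   # hypothetical sizes
    w = torch.randn(d_inner)                     # flat ssm_norm weight from the checkpoint
    w = w.reshape(n_group, d_inner // n_group)   # -> torch.Size([2, 1536])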
 
@@ -6539,6 +6553,113 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
 
 
+@ModelBase.register("FalconH1ForCausalLM")
+class FalconH1Model(Mamba2Model):
+    model_arch = gguf.MODEL_ARCH.FALCON_H1
+
+    def __init__(self, *args, **kwargs):
+        # Set the hparam prefixes for Falcon Mamba2
+        self.hparam_prefixes = ["mamba"]
+
+        # Initialize the base Mamba2Model
+        super().__init__(*args, **kwargs)
+
+        # Use Llama conversion for attention
+        self._transformer_model_class = LlamaModel
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["mamba_d_ssm"])
+        self.d_head = self.find_hparam(["d_head"])
+
+        # Initialize any Falcon Mamba2 specific attributes
+        self.has_attention = True  # Falcon Mamba2 has attention components
+
+        # Load Falcon-H1 multipliers from hyperparameters
+        self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
+        self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
+        self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
+        self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
+        self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
+        self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
+        self.intermediate_size = self.find_hparam(["intermediate_size"])
+        self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors = list(super().modify_tensors(data_torch, name, bid))
+        tensor = tensors[0][1]
+
+        if "down_proj" in name:
+            tensor = tensor * self.mlp_multipliers[1]
+        elif "gate_proj" in name:
+            tensor = tensor * self.mlp_multipliers[0]
+        elif "k_proj" in name:
+            tensor = tensor * self.key_multiplier * self.attention_in_multiplier
+        elif "q_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "v_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "o_proj" in name:
+            tensor = tensor * self.attention_out_multiplier
+        elif "out_proj" in name:
+            tensor = tensor * self.ssm_out_multiplier
+        elif "in_proj" in name:
+            tensor = tensor * self.ssm_in_multiplier
+            zxbcdt_multipliers = self.hparams["ssm_multipliers"]
+            intermediate_size = self.hparams["mamba_d_ssm"]
+            groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
+            tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
+            tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
+            tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
+            tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
+            tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
+        elif "lm_head" in name:
+            tensor = tensor * self.hparams["lm_head_multiplier"]
+        elif "embed_tokens" in name:
+            tensor = tensor * self.hparams["embedding_multiplier"]
+        elif "mamba.norm" in name:
+            tensor = tensor.reshape(self.n_group, self.d_inner // self.n_group)
+
+        tensors = [(tensors[0][0], tensor)]
+        return tensors
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        ## General Params ##
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        # Override some Mamba2 defaults
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Attention params ##
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) # Override value 0 from Mamba2
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
+        self.gguf_writer.add_key_length(self.hparams["head_dim"])
+        self.gguf_writer.add_value_length(self.hparams["head_dim"])
+
+        ## Validation ##
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % self.d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {self.d_head}"
+
+        # Add any other Falcon Mamba2 specific configuration
+        self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
+
+
 @ModelBase.register("HunYuanMoEV1ForCausalLM")
 class HunYuanMoEModel(TextModel):
     model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
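A note on FalconH1Model.modify_tensors above: Falcon-H1's µP-style multipliers are folded into the weights at conversion time, which is valid because (m * W) @ h == m * (W @ h), so the runtime needs no multiplier keys at all. The fused in_proj has to be scaled block-wise, since its output rows concatenate [z | x | B | C | dt]. A shape-level sketch with made-up sizes:

    import torch

    d_ssm, gts, n_heads, d_model = 3072, 256, 48, 2048  # made up; gts = n_groups * d_state
    ssm_multipliers = [0.5, 1.0, 0.75, 0.75, 1.25]      # per-block [z, x, B, C, dt] scales (made up)

    # fused in_proj weight: output rows are [z | x | B | C | dt] stacked
    W = torch.randn(2 * d_ssm + 2 * gts + n_heads, d_model)

    bounds = [0, d_ssm, 2 * d_ssm, 2 * d_ssm + gts, 2 * d_ssm + 2 * gts, W.shape[0]]
    for (lo, hi), m in zip(zip(bounds, bounds[1:]), ssm_multipliers):
        W[lo:hi, :] *= m   # same block boundaries as the converter uses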
index 96a2b692a86c10d16a3920bb66b9c89484f1601c..15a326e695dd5ec56c6fb125cdcbea3fb502bdb6 100755 (executable)
@@ -138,6 +138,11 @@ pre_computed_hashes = [
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
     {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
     {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
+    # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
 ]
 
 
index e938f8fa664dfea1ed5d7ad91724017bebff30a8..93eec43556c744cefcc953e3ea5231c04a922bc3 100644 (file)
@@ -288,6 +288,7 @@ class MODEL_ARCH(IntEnum):
     LLAMA4           = auto()
     DECI             = auto()
     FALCON           = auto()
+    FALCON_H1        = auto()
     BAICHUAN         = auto()
     GROK             = auto()
     GPT2             = auto()
@@ -662,6 +663,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.DOTS1:            "dots1",
     MODEL_ARCH.ARCEE:            "arcee",
     MODEL_ARCH.ERNIE4_5:         "ernie4_5",
+    MODEL_ARCH.FALCON_H1:        "falcon-h1",
     MODEL_ARCH.HUNYUAN_MOE:      "hunyuan-moe",
     MODEL_ARCH.SMOLLM3:          "smollm3",
 }
@@ -2215,6 +2217,40 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.FALCON_H1: [
+        # Token embedding
+        MODEL_TENSOR.TOKEN_EMBD,
+
+        # Input layernorm
+        MODEL_TENSOR.ATTN_NORM,
+
+        # Attention components
+        MODEL_TENSOR.ATTN_Q,         # Query projection
+        MODEL_TENSOR.ATTN_K,         # Key projection
+        MODEL_TENSOR.ATTN_V,         # Value projection
+        MODEL_TENSOR.ATTN_OUT,       # Output projection
+
+        # SSM components (Mamba2 specific)
+        MODEL_TENSOR.SSM_IN,         # Input projection for SSM
+        MODEL_TENSOR.SSM_CONV1D,     # Convolution layer
+        MODEL_TENSOR.SSM_DT,         # Delta time projection
+        MODEL_TENSOR.SSM_A,          # A parameter (log form)
+        MODEL_TENSOR.SSM_D,          # D parameter
+        MODEL_TENSOR.SSM_NORM,       # Normalization in SSM
+        MODEL_TENSOR.SSM_OUT,        # Output projection
+
+        # Pre-feedforward layernorm
+        MODEL_TENSOR.FFN_PRE_NORM,
+
+        # Feed-forward network components
+        MODEL_TENSOR.FFN_GATE,       # Gate projection (SwiGLU)
+        MODEL_TENSOR.FFN_DOWN,       # Down projection
+        MODEL_TENSOR.FFN_UP,         # Up projection
+
+        # Post-feedforward layernorm
+        MODEL_TENSOR.OUTPUT_NORM,    # Final layer norm
+        MODEL_TENSOR.OUTPUT,         # Output projection (lm_head)
+    ],
     MODEL_ARCH.HUNYUAN_MOE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 7c2877f56c64451e1634006f6232eaecad226444..6bddbec23d74ac75cba06c1a7913567caf352115 100644 (file)
@@ -286,12 +286,14 @@ class TensorNameMap:
         # Post feed-forward norm
         MODEL_TENSOR.FFN_PRE_NORM: (
             "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
+            "model.layers.{bid}.pre_ff_layernorm.weight",
         ),
 
         # Post feed-forward norm
         MODEL_TENSOR.FFN_POST_NORM: (
             "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
             "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
+            "model.layers.{bid}.feed_forward.up_proj",
         ),
 
         MODEL_TENSOR.FFN_GATE_INP: (
@@ -363,6 +365,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert.up_proj",          # qwen2moe
             "model.layers.{bid}.mlp.shared_experts.up_proj",         # deepseek deepseek2
             "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
+            "model.layers.{bid}.feed_forward.down_proj",
             "model.layers.{bid}.mlp.shared_mlp.up_proj",             # hunyuan
         ),
 
@@ -553,11 +556,13 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_IN: (
             "model.layers.{bid}.in_proj",
             "backbone.layers.{bid}.mixer.in_proj",
+            "model.layers.{bid}.mamba.in_proj",
         ),
 
         MODEL_TENSOR.SSM_CONV1D: (
             "model.layers.{bid}.conv1d",
             "backbone.layers.{bid}.mixer.conv1d",
+            "model.layers.{bid}.mamba.conv1d",
         ),
 
         MODEL_TENSOR.SSM_X: (
@@ -568,25 +573,30 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_DT: (
             "model.layers.{bid}.dt_proj",
             "backbone.layers.{bid}.mixer.dt_proj",
+            "model.layers.{bid}.mamba.dt_proj",
         ),
 
         MODEL_TENSOR.SSM_A: (
             "model.layers.{bid}.A_log",
             "backbone.layers.{bid}.mixer.A_log",
+            "model.layers.{bid}.mamba.A_log",
         ),
 
         MODEL_TENSOR.SSM_D: (
             "model.layers.{bid}.D",
             "backbone.layers.{bid}.mixer.D",
+            "model.layers.{bid}.mamba.D",
         ),
 
         MODEL_TENSOR.SSM_NORM: (
+            "model.layers.{bid}.mamba.norm", # falcon-h1
             "backbone.layers.{bid}.mixer.norm",  # mamba2
         ),
 
         MODEL_TENSOR.SSM_OUT: (
             "model.layers.{bid}.out_proj",
             "backbone.layers.{bid}.mixer.out_proj",
+            "model.layers.{bid}.mamba.out_proj",          # falcon-h1
         ),
 
         MODEL_TENSOR.TIME_MIX_W0: (
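The new "model.layers.{bid}.mamba.*" entries let TensorNameMap resolve Falcon-H1's checkpoint names onto the generic SSM slots shared with mamba/mamba2. A simplified sketch of how that resolution behaves (not the real class):

    # simplified: each GGUF slot lists the HF name patterns that feed it
    SSM_OUT_PATTERNS = [
        "model.layers.{bid}.out_proj",            # mamba
        "backbone.layers.{bid}.mixer.out_proj",   # mamba2
        "model.layers.{bid}.mamba.out_proj",      # falcon-h1
    ]

    def resolve_ssm_out(hf_name: str, bid: int) -> str | None:
        if any(p.format(bid=bid) == hf_name for p in SSM_OUT_PATTERNS):
            return f"blk.{bid}.ssm_out"
        return None

    assert resolve_ssm_out("model.layers.0.mamba.out_proj", bid=0) == "blk.0.ssm_out"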
index 9af9c2ad604d56aa42844be8ba79b6e58a497aaa..8f4f2df088f70b5b425d8b83246241bead3fa981 100644 (file)
@@ -46,6 +46,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_STARCODER2,       "starcoder2"       },
     { LLM_ARCH_MAMBA,            "mamba"            },
     { LLM_ARCH_MAMBA2,           "mamba2"           },
+    { LLM_ARCH_FALCON_H1,        "falcon-h1"        },
     { LLM_ARCH_XVERSE,           "xverse"           },
     { LLM_ARCH_COMMAND_R,        "command-r"        },
     { LLM_ARCH_COHERE2,          "cohere2"          },
@@ -1024,6 +1025,30 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" },
         },
     },
+    {
+        LLM_ARCH_FALCON_H1,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_SSM_IN,          "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D,      "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT,          "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A,           "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D,           "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_NORM,        "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_XVERSE,
         {
@@ -1967,9 +1992,10 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
 }
 
 bool llm_arch_is_hybrid(const llm_arch & arch) {
-    // TODO: There are currently no hybrid models! Once there are, this will be
-    //  the place to identify them
+    // List all mamba-attention hybrid models here
     switch (arch) {
+        case LLM_ARCH_FALCON_H1:
+            return true;
         default:
             return false;
     }
index ba5d03fa24ebe2c5a45237e8c6d095432107409a..deb3bcd5bc0e3e195b8b1c072dc33d71c90d2f97 100644 (file)
@@ -50,6 +50,7 @@ enum llm_arch {
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_MAMBA2,
+    LLM_ARCH_FALCON_H1,
     LLM_ARCH_XVERSE,
     LLM_ARCH_COMMAND_R,
     LLM_ARCH_COHERE2,
index fc4e9a5af004d235153a9bf4263ec4fa894b7637..e424350bdd78341f83b293400badacd930e5710d 100644 (file)
@@ -1550,6 +1550,37 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_FALCON_H1:
+            {
+                // Common parameters
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                // SSM parameters
+                ml.get_key(LLM_KV_SSM_CONV_KERNEL,    hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_GROUP_COUNT,    hparams.ssm_n_group);
+
+                std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
+
+                switch (hparams.n_layer) {
+                    case 36:
+                        type = LLM_TYPE_0_5B; break;
+                    case 24:
+                        type = LLM_TYPE_1_5B; break;
+                    case 66:
+                        type = LLM_TYPE_1B; break;
+                    case 32:
+                        type = LLM_TYPE_3B; break;
+                    case 44:
+                        type = LLM_TYPE_7B; break;
+                    case 72:
+                        type = LLM_TYPE_34B; break;
+                    default:
+                        type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_HUNYUAN_MOE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
@@ -4497,6 +4528,83 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_FALCON_H1:
+                {
+                    // Common
+                    const int64_t hidden_size = hparams.n_embd; // hidden_size
+
+                    // mamba2 Mixer SSM params
+                    const int64_t ssm_conv_kernel_size  = hparams.ssm_d_conv; // ssm_conv_kernel_size
+                    const int64_t ssm_n_groups          = hparams.ssm_n_group; // ssm_n_groups
+                    const int64_t ssm_state_size        = hparams.ssm_d_state; // ssm_state_size
+                    const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand
+                    const int64_t ssm_num_heads         = hparams.ssm_dt_rank; // ssm_num_heads
+                    const int64_t ssm_conv_dim          = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size;
+                    const int64_t ssm_projection_size   = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads;
+
+                    // attn params
+                    const int64_t attn_num_attention_head = hparams.n_head(0); // attn_num_attention_head
+                    const int64_t attn_num_key_value_head = hparams.n_head_kv(0);
+
+                    // ffn params
+                    const int64_t ffn_intermediate_size = hparams.n_ff(0);
+
+                    // embeddings
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0);
+
+                    // output
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED);
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0);
+                    
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        /*SSM LAYERS*/
+                        // ssm in
+                        layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0);
+                        // ssm 1d conv
+                        layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0);
+                        layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, TENSOR_NOT_REQUIRED);
+                        // ssm_dt
+                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0);
+                        // no "weight" suffix for these
+                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0);
+                        layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0);
+                        // ssm_norm
+                        layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, TENSOR_NOT_REQUIRED);
+                        // out_proj
+                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0);
+
+                        /*ATTENTION LAYERS*/
+                        // attention layers (with optional bias)
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {hidden_size, n_embd_head_k * attn_num_attention_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_k}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_v}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {attn_num_key_value_head * n_embd_head_k}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {attn_num_key_value_head * n_embd_head_v}, TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED);
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0);
+
+
+                        // feed forward (w/ optional biases)
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, i), {hidden_size}, 0);
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size,   ffn_intermediate_size}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  ffn_intermediate_size, hidden_size}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {hidden_size,   ffn_intermediate_size}, 0);
+
+                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED);
+                    }
+                } break;
             case LLM_ARCH_HUNYUAN_MOE:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -10147,7 +10255,7 @@ struct llm_build_mamba : public llm_graph_context {
 
         // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
-        // cb(cur, "mamba_out", il);
+        cb(cur, "mamba_out", il);
 
         return cur;
     }
@@ -14598,6 +14706,267 @@ struct llm_build_ernie4_5 : public llm_graph_context {
     }
 };
 
+struct llm_build_falcon_h1 : public llm_graph_context {
+    const llama_model & model;
+
+    llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model)  {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        // Build the inputs in the recurrent & kv cache
+        auto * inp = build_inp_mem_hybrid();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, hparams.rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, hparams.rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cb(Qcur, "Qcur-post-rope", il);
+            cb(Kcur, "Kcur-post-rope", il);
+            cb(Vcur, "Vcur-post-rope", il);
+
+            ggml_tensor * attn_out = build_attn(inp, gf,
+                    model.layers[il].wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+            cb(attn_out, "attn_out", il);
+
+            cur = build_norm(inpL,
+                model.layers[il].attn_norm, NULL,
+                LLM_NORM_RMS, il);
+            // Mamba2 layer
+            cb(cur, "ssm_in", il);
+
+            ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il);
+            cb(ssm_out, "ssm_out", il);
+
+            // Aggregation
+            cur = ggml_add(ctx0, attn_out, ssm_out);
+            inpSA = ggml_add(ctx0, cur, inpSA);
+            cb(cur, "layer_out", il);
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = inpSA;
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b, NULL,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, inpSA);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    ggml_tensor * build_mamba2_layer(
+        llm_graph_input_mem_hybrid * inp,
+                                ggml_cgraph * gf,
+                                ggml_tensor * cur,
+                         const llama_ubatch & ubatch,
+                                      int   il) const {
+        const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+
+        const auto kv_head = kv_state->get_head();
+
+        const int64_t d_conv  = hparams.ssm_d_conv;
+        const int64_t d_inner = hparams.ssm_d_inner;
+        const int64_t d_state = hparams.ssm_d_state;
+        const int64_t n_head  = hparams.ssm_dt_rank;
+        const int64_t head_dim = d_inner / n_head;
+        const int64_t n_group = hparams.ssm_n_group;
+        const int64_t n_seqs  = ubatch.n_seqs;
+
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        ggml_tensor * conv_states_all = kv_state->get_r_l(il);
+        ggml_tensor * ssm_states_all  = kv_state->get_s_l(il);
+
+        ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
+        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+        // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
+
+        // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
+        ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
+        cb(zxBCdt, "zxBCdt", il);
+
+        // split the above in three
+        ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
+        ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
+        ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
+
+        // conv
+        {
+            // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
+            ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
+
+            // copy last (d_conv - 1) columns back into the state cache
+            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0, last_conv,
+                    ggml_view_1d(ctx0, conv_states_all,
+                        (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
+                        kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
+
+            // 1D convolution
+            // The equivalent is to make a self-overlapping view of conv_x
+            // over d_conv columns at each stride in the 3rd dimension,
+            // then element-wise multiply that with the conv1d weight,
+            // then sum the elements of each row,
+            // (the last two steps are a dot product over rows (also doable with mul_mat))
+            // then permute away the ne[0] dimension,
+            // and then you're left with the resulting x tensor.
+            // For simultaneous sequences, all sequences need to have the same length.
+            xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
+
+            // bias
+            xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
+
+            xBC = ggml_silu(ctx0, xBC);
+        }
+
+        // ssm
+        {
+            // These correspond to V K Q in SSM/attention duality
+            ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
+
+            ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
+
+            ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
+
+            // {n_head, n_seq_tokens, n_seqs}
+            dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
+
+            ggml_tensor * A = model.layers[il].ssm_a;
+
+            // use the states and the indices provided by build_rs
+            // (this is necessary in order to properly use the states before they are overwritten,
+            //  while avoiding to make unnecessary copies of the states)
+            auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+                ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
+
+                // TODO: use semistructured matrices to implement state-space duality
+                // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+                return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+            };
+
+            ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
+
+            // store last states
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0,
+                    ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
+                    ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+
+            ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
+
+            // TODO: skip computing output earlier for unused tokens
+
+            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+            y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
+
+            // grouped RMS norm
+            if (model.layers[il].ssm_norm) {
+                y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
+                y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+            }
+
+            y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
+
+            // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+            cur = build_lora_mm(model.layers[il].ssm_out, y);
+        }
+
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+        cb(cur, "mamba_out", il);
+        return cur;
+    }
+};
+
 struct llm_build_arcee : public llm_graph_context {
     llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -15077,7 +15446,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* recurrent_type_v  */ GGML_TYPE_F32,
                         /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
                         /* n_seq_max         */ cparams.n_seq_max,
-                        /* offload           */ cparams.offload_kqv);
+                        /* offload           */ cparams.offload_kqv,
+                        /* filter_attn       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
+                        /* filter_recr       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
                 } else {
                     const auto padding = llama_kv_cache_unified::get_padding(cparams);
 
@@ -15419,6 +15790,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_smollm3>(*this, params, gf);
             } break;
+        case LLM_ARCH_FALCON_H1:
+            {
+                llm = std::make_unique<llm_build_falcon_h1>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -15577,6 +15952,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
 
         // the pairs of head values are offset by n_rot/2
         case LLM_ARCH_FALCON:
+        case LLM_ARCH_FALCON_H1:
         case LLM_ARCH_GROK:
         case LLM_ARCH_DBRX:
         case LLM_ARCH_BERT:
index 551bba171c0e0ea26b9363d42f7c1e8219b6d836..b7f14dc07b60936d748c01f294ad51e2852c73c5 100644 (file)
@@ -1523,6 +1523,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     tokenizer_pre == "llama-v3" ||
                     tokenizer_pre == "llama-bpe"||
                     tokenizer_pre == "falcon3"  ||
+                    tokenizer_pre == "falcon-h1" ||
                     tokenizer_pre == "pixtral") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
                 ignore_merges = true;
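With the vocab wired up, a Falcon-H1 checkpoint should convert and run end to end: point convert_hf_to_gguf.py at a local Falcon-H1 download (e.g. Falcon-H1-0.5B-Base) with --outfile, then load the resulting GGUF with llama-cli -m as usual; the hybrid attention/recurrent cache should require no extra flags.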