git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : support Jamba hybrid Transformer-Mamba models (#7531)
author    compilade <redacted>
Wed, 9 Jul 2025 18:59:57 +0000 (14:59 -0400)
committer GitHub <redacted>
Wed, 9 Jul 2025 18:59:57 +0000 (14:59 -0400)
* wip: llama : separate recurrent states from the KV cache

This will be necessary to support Jamba
(and other recurrent models mixed with Attention).

Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states.

* llama : use std::find for seq_nodes in llama_rs_cache

* llama : state checkpoints for recurrent models

* llama : correctly handle more edge cases for the rs cache

* llama : rename many llama_kv_cache_* functions

* llama : remove useless return value for some llama_cache_* functions

* llama : rethink recurrent state cell counts

* llama : begin work on support for variable GQA

This will also be useful for Jamba if we consider the Mamba layers
to have 0 KV heads.
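
For illustration, here is roughly what the per-layer KV head counts look like for a hybrid model, in the spirit of the n_kv_vec computation in convert_hf_to_gguf.py below (the hparam values are made up for this sketch):

    # Mamba layers get 0 KV heads; attention layers keep the real count.
    n_kv_head, attn_offset, attn_period, block_count = 8, 4, 8, 32
    n_kv_vec = [0] * attn_offset + [
        n_kv_head if (i - attn_offset) % attn_period == 0 else 0
        for i in range(attn_offset, block_count)
    ]
    # -> [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 8, 0, ...]
    # A layer with n_head_kv == 0 can then be treated as a recurrent (Mamba) layer.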

* llama : gracefully fail when not finding hybrid slot

* llama : support Jamba

* llama : fix BERT inference without KV cache

* convert-hf : check for unprocessed Jamba experts

* convert-hf : support Mini-Jamba conversion

* llama : fix Jamba quantization sanity checks

* llama : sequence-length-aware batch splitting

* llama : use equal-sequence-length sub-batches for recurrent models
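
The idea, as a hypothetical sketch (not the actual llama.cpp splitting code, and ignoring the ubatch size limit): recurrent layers view a ubatch as {.., n_seq_tokens, n_seqs}, so every sequence in a ubatch must contribute the same number of tokens.

    def split_equal_length(seq_lens: dict[int, int]) -> list[dict[int, int]]:
        # seq_lens maps seq_id -> number of pending tokens for that sequence
        ubatches = []
        remaining = {s: n for s, n in seq_lens.items() if n > 0}
        while remaining:
            take = min(remaining.values())  # largest count every remaining sequence can still supply
            ubatches.append({s: take for s in remaining})
            remaining = {s: n - take for s, n in remaining.items() if n > take}
        return ubatches

    # split_equal_length({0: 5, 1: 2, 2: 2}) -> [{0: 2, 1: 2, 2: 2}, {0: 3}]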

* ggml : simplify SSM-related operators

* llama : make recurrent state slot allocation contiguous

* llama : adapt internal uses of batches to llama_ubatch

* llama : fix batch split output count for embeddings

* llama : minimize swaps when reordering logits

This reduces overhead when running HellaSwag
on thousands of sequences with very small (100k-parameter) Mamba models.
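
One standard way to reorder with minimal data movement (a sketch only; the code in this commit may differ) is to follow permutation cycles, so each slot is written once instead of being swapped repeatedly:

    # Illustrative only: src[i] is the index whose value should end up at slot i.
    def reorder_min_moves(values: list, src: list[int]) -> list:
        done = [False] * len(values)
        for i in range(len(values)):
            if done[i] or src[i] == i:
                continue
            j, tmp = i, values[i]        # remember the value that starts the cycle
            while True:
                k = src[j]
                done[j] = True
                if k == i:               # cycle closed: drop the saved value back in
                    values[j] = tmp
                    break
                values[j] = values[k]    # shift the next element of the cycle into place
                j = k
        return values

    # reorder_min_moves(['a', 'b', 'c'], [2, 0, 1]) -> ['c', 'a', 'b']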

* llama : fix edge case finding batch seq_id of split recurrent cell

This was otherwise a problem when running the HellaSwag benchmark
with small batch sizes, causing it to crash.

* llama : avoid copies for simple batch splits

* ggml : make ggml_ssm_scan not modify its source tensors

* llama : fix shared recurrent tail cell count for small ubatch sizes

Otherwise it was impossible to run the 'parallel' example with '-ub 1'
with a Mamba or Jamba model.

* llama : fix .base() compilation error on Windows

* llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
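
The equivalence being used here, as a NumPy sketch (illustrative shapes, laid out as {time, channel} rather than ggml's actual tensor layout): the depthwise causal convolution of SSM_CONV gives the same result as an element-wise MUL of each sliding window with the kernel followed by a sum over the kernel dimension.

    import numpy as np

    d_conv, d_inner, n_tokens = 4, 8, 6
    w = np.random.randn(d_conv, d_inner)                 # per-channel conv kernel
    x = np.random.randn(d_conv - 1 + n_tokens, d_inner)  # tokens with (d_conv - 1) past states prepended

    # depthwise causal convolution, one channel at a time (what SSM_CONV computes)
    ref = np.stack([np.correlate(x[:, c], w[:, c], mode="valid") for c in range(d_inner)], axis=1)

    # the same result from MUL of each sliding window with the kernel,
    # then a sum over the kernel (rows) dimension
    windows = np.stack([x[t:t + d_conv] for t in range(n_tokens)])  # {n_tokens, d_conv, d_inner}
    alt = (windows * w).sum(axis=1)

    assert np.allclose(ref, alt)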

* ggml : allow GGML_OP_CONCAT to work on non-contiguous tensors

The implementation already supported it,
and this makes Mamba's conv step slightly faster.

* mamba : fix non-contiguous usage of ggml_silu

* llama : session saving and reloading for hybrid models

* convert_hf : fix Jamba conversion

* llama : fix mixed signedness comparison

* llama : use unused n_embd_k_gqa in k_shift

This also slightly reduces the diff from the master branch.

* llama : begin renaming llama_past back to llama_kv_cache

* llama : remove implicit recurrent state rollbacks

* llama : partially apply clang-format style

* convert : fix jamba conv1d shape squeezing

* graph : add back hybrid memory graph input

But this time it contains the sub-cache graph inputs.
This *should* make it easier to handle updating the inputs
when caching the graph (eventually).

* model : add Jamba to Mamba-specific hparams printing

* jamba : remove redundant nullptr initializations

* model : remove unnecessary prefix for tensor loading constants

Co-authored-by: Sigbjørn Skjæret <redacted>
* model : use ggml_swiglu_split for Mamba

Co-authored-by: Sigbjørn Skjæret <redacted>
* model : make falcon-h1 use shared mamba2 layer builder

* memory : avoid referring to KV in recurrent cache logs

* gguf-py : avoid adding duplicate tensor mappings for Jamba

Some of the tensor names are shared with Llama4.

---------

Co-authored-by: Sigbjørn Skjæret <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-graph.cpp
src/llama-graph.h
src/llama-memory-recurrent.cpp
src/llama-model.cpp
src/llama-model.h

index 702827f4d5d2e3d07a77e3a1662066ec808828d3..2419126ec4ea2d52943a95ef620304110f993a25 100755 (executable)
@@ -4974,6 +4974,123 @@ class Mamba2Model(TextModel):
         yield (new_name, data_torch)
 
 
+@ModelBase.register("JambaForCausalLM")
+class JambaModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.JAMBA
+
+    def get_vocab_base_pre(self, tokenizer) -> str:
+        del tokenizer  # unused
+
+        return "gpt-2"
+
+    def set_vocab(self):
+        if (self.dir_model / "tokenizer.model").is_file():
+            # Using Jamba's tokenizer.json causes errors on model load
+            # (something about "byte not found in vocab"),
+            # but there's a working tokenizer.model
+            self._set_vocab_sentencepiece()
+        else:
+            # Some Jamba models only have a tokenizer.json, which works.
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
+        d_conv  = self.find_hparam(["mamba_d_conv"],  optional=True) or 4
+        d_inner = self.hparams["mamba_expand"] * d_model
+        d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank      = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
+        n_kv_head = self.hparams["num_key_value_heads"]
+        attn_offset = self.hparams["attn_layer_offset"]
+        attn_period = self.hparams["attn_layer_period"]
+        n_kv_vec = [0 for _ in range(attn_offset)] + [
+            n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
+        ]
+
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(n_kv_vec)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(dt_rank)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        # Mini-Jamba
+        name = name.replace(".moe.", ".feed_forward.")
+        if bid is not None:
+            moe_offset = self.hparams["expert_layer_offset"]
+            moe_period = self.hparams["expert_layer_period"]
+
+            if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
+                name = name.replace(".experts.0.", ".")
+
+        # process the experts separately
+        if ".feed_forward.experts." in name:
+            n_experts = self.hparams["num_experts"]
+
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+
+                # merge the experts into a single 3d tensor
+                for wid in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    # using the same merged name as qwen2moe
+                    merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    yield new_name, data_torch
+            return
+
+        new_name = self.map_tensor_name(name)
+
+        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
+            data_torch = data_torch.squeeze()
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield (new_name, data_torch)
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
index 93eec43556c744cefcc953e3ea5231c04a922bc3..fbe3f53273a355ddd8fbd5d400b6949a778bb40d 100644 (file)
@@ -330,6 +330,7 @@ class MODEL_ARCH(IntEnum):
     ARWKV7           = auto()
     MAMBA            = auto()
     MAMBA2           = auto()
+    JAMBA            = auto()
     XVERSE           = auto()
     COMMAND_R        = auto()
     COHERE2          = auto()
@@ -432,7 +433,10 @@ class MODEL_TENSOR(IntEnum):
     SSM_CONV1D           = auto()
     SSM_X                = auto()
     SSM_DT               = auto()
+    SSM_DT_NORM          = auto()
     SSM_A                = auto()
+    SSM_B_NORM           = auto()
+    SSM_C_NORM           = auto()
     SSM_D                = auto()
     SSM_NORM             = auto()
     SSM_OUT              = auto()
@@ -635,6 +639,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.ARWKV7:           "arwkv7",
     MODEL_ARCH.MAMBA:            "mamba",
     MODEL_ARCH.MAMBA2:           "mamba2",
+    MODEL_ARCH.JAMBA:            "jamba",
     MODEL_ARCH.XVERSE:           "xverse",
     MODEL_ARCH.COMMAND_R:        "command-r",
     MODEL_ARCH.COHERE2:          "cohere2",
@@ -738,7 +743,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.SSM_CONV1D:                "blk.{bid}.ssm_conv1d",
     MODEL_TENSOR.SSM_X:                     "blk.{bid}.ssm_x",
     MODEL_TENSOR.SSM_DT:                    "blk.{bid}.ssm_dt",
+    MODEL_TENSOR.SSM_DT_NORM:               "blk.{bid}.ssm_dt_norm",
     MODEL_TENSOR.SSM_A:                     "blk.{bid}.ssm_a",
+    MODEL_TENSOR.SSM_B_NORM:                "blk.{bid}.ssm_b_norm",
+    MODEL_TENSOR.SSM_C_NORM:                "blk.{bid}.ssm_c_norm",
     MODEL_TENSOR.SSM_D:                     "blk.{bid}.ssm_d",
     MODEL_TENSOR.SSM_NORM:                  "blk.{bid}.ssm_norm",
     MODEL_TENSOR.SSM_OUT:                   "blk.{bid}.ssm_out",
@@ -1738,6 +1746,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.SSM_NORM,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.JAMBA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_X,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_DT_NORM,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_B_NORM,
+        MODEL_TENSOR.SSM_C_NORM,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.XVERSE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 6bddbec23d74ac75cba06c1a7913567caf352115..215eb297ebcc1fcec2a97017f62644884805a4a1 100644 (file)
@@ -279,6 +279,8 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.rms_norm_2",                    # Grok
             "encoder.layers.{bid}.post_attention_layernorm",                 # chatglm
             "transformer.layers.{bid}.ffn_norm",                             # openelm
+            "model.layers.{bid}.pre_ff_layernorm",                           # jamba
+            "model.layers.{bid}.pre_moe_layernorm",                          # mini-jamba
             "model.layers.{bid}.post_attention_layernorm",                   # llama4
             "transformer_encoder.{bid}.ffn_norm",                            # neobert
         ),
@@ -303,7 +305,7 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.router",           # Grok
             "transformer.blocks.{bid}.ffn.router.layer",        # dbrx
             "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
-            "model.layers.{bid}.feed_forward.router",           # llama4
+            "model.layers.{bid}.feed_forward.router",           # llama4 jamba
             "encoder.layers.{bid}.mlp.router.layer",            # nomic-bert-moe
             "model.layers.{bid}.mlp.gate.wg",                   # hunyuan
         ),
@@ -347,7 +349,7 @@ class TensorNameMap:
             "model.layers.{bid}.residual_mlp.w3",                     # arctic
             "encoder.layers.{bid}.mlp.dense_h_to_4h",                 # chatglm
             "transformer.h.{bid}.mlp.c_fc_1",                         # exaone
-            "model.layers.{bid}.feed_forward.up_proj",                # llama4
+            "model.layers.{bid}.feed_forward.up_proj",                # llama4 jamba
             "transformer_encoder.{bid}.ffn.w12",                      # neobert
         ),
 
@@ -387,7 +389,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.linear_1",           # refact
             "model.layers.{bid}.residual_mlp.w1",         # arctic
             "transformer.h.{bid}.mlp.c_fc_0",             # exaone
-            "model.layers.{bid}.feed_forward.gate_proj",  # llama4
+            "model.layers.{bid}.feed_forward.gate_proj",  # llama4 jamba
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -433,7 +435,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.mlp.down_layer",                     # jina-bert-v2
             "encoder.layers.{bid}.mlp.dense_4h_to_h",                 # chatglm
             "model.layers.h.{bid}.mlp.c_proj",                        # exaone
-            "model.layers.{bid}.feed_forward.down_proj",              # llama4
+            "model.layers.{bid}.feed_forward.down_proj",              # llama4 jamba
             "transformer_encoder.{bid}.ffn.w3",                       # neobert
         ),
 
@@ -554,38 +556,53 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_IN: (
-            "model.layers.{bid}.in_proj",
-            "backbone.layers.{bid}.mixer.in_proj",
-            "model.layers.{bid}.mamba.in_proj",
+            "model.layers.{bid}.in_proj",           # mamba-hf
+            "backbone.layers.{bid}.mixer.in_proj",  # mamba
+            "model.layers.{bid}.mamba.in_proj",     # jamba falcon-h1
         ),
 
         MODEL_TENSOR.SSM_CONV1D: (
-            "model.layers.{bid}.conv1d",
-            "backbone.layers.{bid}.mixer.conv1d",
-            "model.layers.{bid}.mamba.conv1d",
+            "model.layers.{bid}.conv1d",           # mamba-hf
+            "backbone.layers.{bid}.mixer.conv1d",  # mamba
+            "model.layers.{bid}.mamba.conv1d",     # jamba falcon-h1
         ),
 
         MODEL_TENSOR.SSM_X: (
-            "model.layers.{bid}.x_proj",
-            "backbone.layers.{bid}.mixer.x_proj",
+            "model.layers.{bid}.x_proj",           # mamba-hf
+            "backbone.layers.{bid}.mixer.x_proj",  # mamba
+            "model.layers.{bid}.mamba.x_proj",     # jamba
         ),
 
         MODEL_TENSOR.SSM_DT: (
-            "model.layers.{bid}.dt_proj",
-            "backbone.layers.{bid}.mixer.dt_proj",
-            "model.layers.{bid}.mamba.dt_proj",
+            "model.layers.{bid}.dt_proj",           # mamba-hf
+            "backbone.layers.{bid}.mixer.dt_proj",  # mamba
+            "model.layers.{bid}.mamba.dt_proj",     # jamba falcon-h1
+        ),
+
+        MODEL_TENSOR.SSM_DT_NORM: (
+            "model.layers.{bid}.mamba.dt_layernorm",  # jamba
         ),
 
         MODEL_TENSOR.SSM_A: (
-            "model.layers.{bid}.A_log",
-            "backbone.layers.{bid}.mixer.A_log",
-            "model.layers.{bid}.mamba.A_log",
+            "model.layers.{bid}.A_log",           # mamba-hf
+            "backbone.layers.{bid}.mixer.A_log",  # mamba
+            "model.layers.{bid}.mamba.A_log",     # jamba falcon-h1
+        ),
+
+        MODEL_TENSOR.SSM_B_NORM: (
+            "model.layers.{bid}.mamba.b_layernorm",  # jamba
+            "model.layers.{bid}.mamba.B_layernorm",  # mini-jamba
+        ),
+
+        MODEL_TENSOR.SSM_C_NORM: (
+            "model.layers.{bid}.mamba.c_layernorm",  # jamba
+            "model.layers.{bid}.mamba.C_layernorm",  # mini-jamba
         ),
 
         MODEL_TENSOR.SSM_D: (
-            "model.layers.{bid}.D",
-            "backbone.layers.{bid}.mixer.D",
-            "model.layers.{bid}.mamba.D",
+            "model.layers.{bid}.D",           # mamba-hf
+            "backbone.layers.{bid}.mixer.D",  # mamba
+            "model.layers.{bid}.mamba.D",     # jamba falcon-h1
         ),
 
         MODEL_TENSOR.SSM_NORM: (
@@ -594,9 +611,9 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_OUT: (
-            "model.layers.{bid}.out_proj",
-            "backbone.layers.{bid}.mixer.out_proj",
-            "model.layers.{bid}.mamba.out_proj",          # falcon-h1
+            "model.layers.{bid}.out_proj",           # mamba-hf
+            "backbone.layers.{bid}.mixer.out_proj",  # mamba
+            "model.layers.{bid}.mamba.out_proj",     # jamba falcon-h1
         ),
 
         MODEL_TENSOR.TIME_MIX_W0: (
index 8f4f2df088f70b5b425d8b83246241bead3fa981..1955c03eb3d1c9170392e38217db9e8abd59f75c 100644 (file)
@@ -46,6 +46,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_STARCODER2,       "starcoder2"       },
     { LLM_ARCH_MAMBA,            "mamba"            },
     { LLM_ARCH_MAMBA2,           "mamba2"           },
+    { LLM_ARCH_JAMBA,            "jamba"            },
     { LLM_ARCH_FALCON_H1,        "falcon-h1"        },
     { LLM_ARCH_XVERSE,           "xverse"           },
     { LLM_ARCH_COMMAND_R,        "command-r"        },
@@ -1025,6 +1026,37 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" },
         },
     },
+    {
+        LLM_ARCH_JAMBA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_SSM_IN,          "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D,      "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_X,           "blk.%d.ssm_x" },
+            { LLM_TENSOR_SSM_DT,          "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_DT_NORM,     "blk.%d.ssm_dt_norm" },
+            { LLM_TENSOR_SSM_A,           "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_B_NORM,      "blk.%d.ssm_b_norm" },
+            { LLM_TENSOR_SSM_C_NORM,      "blk.%d.ssm_c_norm" },
+            { LLM_TENSOR_SSM_D,           "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_FALCON_H1,
         {
@@ -1845,6 +1877,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_ACT,                    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
     {LLM_TENSOR_SSM_CONV1D,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
     {LLM_TENSOR_SSM_A,                      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
+    {LLM_TENSOR_SSM_DT_NORM,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_SSM_B_NORM,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_SSM_C_NORM,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_SSM_D,                      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_SSM_NORM,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_TIME_MIX_LERP_X,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -1994,6 +2029,7 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
 bool llm_arch_is_hybrid(const llm_arch & arch) {
     // List all mamba-attention hybrid models here
     switch (arch) {
+        case LLM_ARCH_JAMBA:
         case LLM_ARCH_FALCON_H1:
             return true;
         default:
index deb3bcd5bc0e3e195b8b1c072dc33d71c90d2f97..3381b8dc4a4b79796b398755a65aefcb480bef07 100644 (file)
@@ -50,6 +50,7 @@ enum llm_arch {
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_MAMBA2,
+    LLM_ARCH_JAMBA,
     LLM_ARCH_FALCON_H1,
     LLM_ARCH_XVERSE,
     LLM_ARCH_COMMAND_R,
@@ -296,7 +297,10 @@ enum llm_tensor {
     LLM_TENSOR_SSM_CONV1D,
     LLM_TENSOR_SSM_X,
     LLM_TENSOR_SSM_DT,
+    LLM_TENSOR_SSM_DT_NORM,
     LLM_TENSOR_SSM_A,
+    LLM_TENSOR_SSM_B_NORM,
+    LLM_TENSOR_SSM_C_NORM,
     LLM_TENSOR_SSM_D,
     LLM_TENSOR_SSM_NORM,
     LLM_TENSOR_SSM_OUT,
index 7f0e8c67f13253c27b9154551ece542df69ecf89..55a059d0975d25e44857966ac519608239764bc6 100644 (file)
@@ -336,22 +336,8 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
-    mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
-
-    mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-
-    const int64_t n_rs = mctx->get_recr()->get_n_rs();
-
-    if (s_copy) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
-        int32_t * data = (int32_t *) s_copy->data;
-
-        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mctx->get_recr()->s_copy(i);
-        }
-    }
+    inp_attn->set_input(ubatch);
+    inp_rs->set_input(ubatch);
 }
 
 void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
@@ -992,35 +978,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }
 
-llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
-
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
-
-    {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
-
-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
-
-        inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
-        inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
-
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        const auto n_rs = mctx_cur->get_recr()->get_n_rs();
-
-        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
-        ggml_set_input(inp->s_copy);
-    }
-
-    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn_mha(
          ggml_cgraph * gf,
          ggml_tensor * q,
@@ -1194,8 +1151,12 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+           ggml_context * ctx0,
+     const llama_ubatch & ubatch,
+    const llama_hparams & hparams,
+    const llama_cparams & cparams,
+    const llama_kv_cache_unified_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
@@ -1203,6 +1164,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
         const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_tokens = ubatch.n_tokens;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
@@ -1213,6 +1175,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    return inp;
+}
+
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+
+    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
@@ -1234,7 +1204,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+    const auto * mctx_cur = inp->mctx;
 
     // store to KV cache
     {
@@ -1293,7 +1263,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, v_cur);
     }
 
-    const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+    const auto * mctx_iswa = inp->mctx;
 
     const bool is_swa = hparams.is_swa(il);
 
@@ -1391,59 +1361,9 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_mem_hybrid * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-            float     kq_scale,
-            int       il) const {
-    // these nodes are added to the graph together so that they are not reordered
-    // by doing so, the number of splits in the graph is reduced
-    ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
-
-    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
-
-    // store to KV cache
-    {
-        const auto & k_idxs = inp->get_k_idxs();
-        const auto & v_idxs = inp->get_v_idxs();
-
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
-    }
-
-    const auto & kq_mask = inp->get_kq_mask();
-
-    ggml_tensor * q = q_cur;
-    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
-    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
-    cb(cur, "kqv_out", il);
-
-    if (wo) {
-        cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
-            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-        }
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
-    return cur;
-}
-
+// TODO: maybe separate the inner implementation into a separate function
+//       like with the non-sliding window equivalent
+//       once sliding-window hybrid caches are a thing.
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
     const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
@@ -1513,8 +1433,9 @@ ggml_tensor * llm_graph_context::build_rs(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+           ggml_context * ctx0,
+    const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
@@ -1523,29 +1444,25 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
-    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+    return inp;
 }
 
-ggml_tensor * llm_graph_context::build_rs(
-        llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-            int32_t   state_size,
-            int32_t   n_seqs,
-        const llm_graph_get_rows_fn & get_state_rows) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_context *>(mctx);
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-        llm_graph_input_mem_hybrid * inp,
+        llm_graph_input_rs * inp,
         ggml_cgraph * gf,
         ggml_tensor * s,
             int32_t   state_size,
             int32_t   n_seqs,
         const llm_graph_get_rows_fn & get_state_rows) const {
-    const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+    const auto * kv_state = inp->mctx;
 
     return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }
@@ -1592,6 +1509,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     );
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 void llm_graph_context::build_pooling(
         ggml_cgraph * gf,
         ggml_tensor * cls,
index 7bdf656768a0c1d5fda6e2cd4645b3ca2a5eb492..54eaaac02b99e083d070cfbe4d72ade581cbd575 100644 (file)
@@ -322,32 +322,21 @@ public:
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
-            const llama_hparams & hparams,
-            const llama_cparams & cparams,
-            const llama_memory_hybrid_context * mctx) :
-        hparams(hparams),
-        cparams(cparams),
-        mctx(mctx) {
-    }
+            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_rs>              inp_rs,
+            const llama_memory_hybrid_context *              mctx) :
+        inp_attn(std::move(inp_attn)),
+        inp_rs(std::move(inp_rs)),
+        mctx(mctx) { }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * s_copy; // I32 [kv_size]
-
-    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
-    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
-
-    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
-
-    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
+    std::unique_ptr<llm_graph_input_rs>              inp_rs;
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
-
-    const llama_hparams & hparams;
-    const llama_cparams & cparams;
+    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs              * get_recr() const { return inp_rs.get(); }
 
     const llama_memory_hybrid_context * mctx;
 };
@@ -579,8 +568,6 @@ struct llm_graph_context {
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
 
-    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
-
     //
     // attention
     //
@@ -656,18 +643,6 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
-    ggml_tensor * build_attn(
-            llm_graph_input_mem_hybrid * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-                  float   kq_scale,
-                    int   il) const;
     //
     // recurrent
     //
@@ -700,14 +675,6 @@ struct llm_graph_context {
                 int32_t   n_seqs,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
-    ggml_tensor * build_rs(
-            llm_graph_input_mem_hybrid * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-                int32_t   state_size,
-                int32_t   n_seqs,
-            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
-
     ggml_tensor * build_rwkv_token_shift_load(
         llm_graph_input_rs * inp,
                ggml_cgraph * gf,
@@ -718,6 +685,11 @@ struct llm_graph_context {
              ggml_tensor * token_shift,
       const llama_ubatch & ubatch,
                      int   il) const;
+    //
+    // hybrid
+    //
+
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
 
     //
     // pooling
index a1b5b1a272cc09d02db4ecd7f5aa0686640f38ed..2c1ae67098ca49445ca0305e6b71cda23f09f0ce 100644 (file)
@@ -25,9 +25,6 @@ llama_memory_recurrent::llama_memory_recurrent(
                  uint32_t    n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
-    LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
-            __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
-
     head = 0;
     size = mem_size;
     used = 0;
@@ -84,7 +81,7 @@ llama_memory_recurrent::llama_memory_recurrent(
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
-            throw std::runtime_error("failed to create ggml context for kv cache");
+            throw std::runtime_error("failed to create ggml context for rs cache");
         }
 
         ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
@@ -102,10 +99,10 @@ llama_memory_recurrent::llama_memory_recurrent(
 
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
         if (!buf) {
-            throw std::runtime_error("failed to allocate buffer for kv cache");
+            throw std::runtime_error("failed to allocate buffer for rs cache");
         }
         ggml_backend_buffer_clear(buf, 0);
-        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+        LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
 
@@ -113,8 +110,8 @@ llama_memory_recurrent::llama_memory_recurrent(
         const size_t memory_size_r = size_r_bytes();
         const size_t memory_size_s = size_s_bytes();
 
-        LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
+        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max,
                 ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
                 ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
     }
index 4468c837f1c28e9dd018ac0b7cf3ba3fa685f243..c21cc28806c75e5a1222e8ba1400b52ad96bb813 100644 (file)
@@ -1118,6 +1118,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_JAMBA:
+            {
+                ml.get_key(LLM_KV_SSM_CONV_KERNEL,    hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+                    hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0;
+                }
+
+                switch (hparams.n_layer) {
+                    // TODO: Jamba layers are a bit heterogeneous, so naming this is hard.
+                    case 12: // 900M  8x???M
+                    case 32: // 51B  16x?B
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_XVERSE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -3231,10 +3251,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     {
                         output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
 
-                        output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
                         // if output is NULL, init from the input tok embed, duplicated to allow offloading
                         if (output == NULL) {
-                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                         }
                     }
 
@@ -3261,6 +3281,87 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_JAMBA:
+                {
+                    const int64_t d_conv  = hparams.ssm_d_conv;
+                    const int64_t d_inner = hparams.ssm_d_inner;
+                    const int64_t d_state = hparams.ssm_d_state;
+                    const int64_t dt_rank = hparams.ssm_dt_rank;
+
+                    // only an expansion factor of 2 is supported for now
+                    GGML_ASSERT(2 * n_embd == d_inner);
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    {
+                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                        output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                        // if output is NULL, init from the input tok embed, duplicated to allow offloading
+                        if (output == NULL) {
+                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        const int64_t n_head_kv = hparams.n_head_kv(i);
+                        const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i);
+
+                        auto & layer = layers[i];
+
+                        // norm
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (n_head_kv == 0) {
+                            // Mamba layer
+                            layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0);
+
+                            layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0);
+                            layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0);
+
+                            layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0);
+
+                            layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}, 0);
+
+                            layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0);
+                            layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0);
+
+                            layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}, 0);
+                            layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}, 0);
+
+                            // no "weight" suffix for these
+                            layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
+                            layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
+
+                            // out_proj
+                            layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
+                        } else {
+                            // Attention layers
+
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        }
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED);
+
+                        if (layer.ffn_gate_inp) {
+                            // MoE
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert}, 0);
+                        } else {
+                            // FFN (no MoE)
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        }
+                    }
+                } break;
             case LLM_ARCH_XVERSE:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4910,16 +5011,6 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
         LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
         LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
-    }
-
-    if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) {
-        LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
-        LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
-        LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
-        LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
-        LLAMA_LOG_INFO("%s: ssm_n_group      = %u\n",     __func__, hparams.ssm_n_group);
-        LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
-
         if (!classifier_labels.empty()) {
             LLAMA_LOG_INFO("%s: n_cls_out        = %u\n", __func__, hparams.n_cls_out);
 
@@ -4930,6 +5021,18 @@ void llama_model::print_info() const {
         }
     }
 
+    if (arch == LLM_ARCH_MAMBA ||
+        arch == LLM_ARCH_MAMBA2 ||
+        arch == LLM_ARCH_JAMBA ||
+        arch == LLM_ARCH_FALCON_H1) {
+        LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
+        LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
+        LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
+        LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+        LLAMA_LOG_INFO("%s: ssm_n_group      = %u\n",     __func__, hparams.ssm_n_group);
+        LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
+    }
+
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, type_name().c_str());
     if (pimpl->n_elements >= 1e12) {
         LLAMA_LOG_INFO("%s: model params     = %.2f T\n", __func__, pimpl->n_elements*1e-12);
@@ -9935,62 +10038,8 @@ struct llm_build_starcoder2 : public llm_graph_context {
     }
 };
 
-struct llm_build_mamba : public llm_graph_context {
-    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
-        ggml_tensor * cur;
-        ggml_tensor * inpL;
-
-        // {n_embd, n_tokens}
-        inpL = build_inp_embd(model.tok_embd);
-
-        auto * rs_inp = build_rs_inp();
-
-        ggml_tensor * inp_out_ids = build_inp_out_ids();
-
-        for (int il = 0; il < n_layer; ++il) {
-            // norm
-            cur = build_norm(inpL,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, il);
-            cb(cur, "attn_norm", il);
-
-            if (model.arch == LLM_ARCH_MAMBA2) {
-                cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il);
-            } else {
-                cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il);
-            }
-
-            if (il == n_layer - 1 && inp_out_ids) {
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // residual
-            cur = ggml_add(ctx0, cur, inpL);
-
-            cur = build_cvec(cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        // final rmsnorm
-        cur = build_norm(inpL,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, -1);
-
-        cb(cur, "result_norm", -1);
-        res->t_embd = cur;
-
-        // lm_head
-        cur = build_lora_mm(model.output, cur);
-
-        cb(cur, "result_output", -1);
-        res->t_logits = cur;
-
-        ggml_build_forward_expand(gf, cur);
-    }
+struct llm_graph_context_mamba : public llm_graph_context {
+    llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}
 
     ggml_tensor * build_mamba_layer(
         llm_graph_input_rs * inp,
@@ -9998,11 +10047,14 @@ struct llm_build_mamba : public llm_graph_context {
                ggml_tensor * cur,
          const llama_model & model,
         const llama_ubatch & ubatch,
-                       int   il) const {
-        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+                       int   il) {
+
+        const auto * mctx_cur = inp->mctx;
 
         const auto kv_head = mctx_cur->get_head();
 
+        const auto & layer = model.layers[il];
+
         const int64_t d_conv  = hparams.ssm_d_conv;
         const int64_t d_inner = hparams.ssm_d_inner;
         const int64_t d_state = hparams.ssm_d_state;
@@ -10012,8 +10064,6 @@ struct llm_build_mamba : public llm_graph_context {
         const int64_t n_seqs  = ubatch.n_seqs;
         // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
         const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
-        // Use the same RMS norm as the final layer norm
-        const float norm_rms_eps = hparams.f_norm_rms_eps;
 
         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
@@ -10031,7 +10081,7 @@ struct llm_build_mamba : public llm_graph_context {
         cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
 
         // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
-        ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, cur);
+        ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur);
         // split the above in two
         // => {d_inner, n_seq_tokens, n_seqs}
         ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
@@ -10060,10 +10110,10 @@ struct llm_build_mamba : public llm_graph_context {
             // then permute away the ne[0] dimension,
             // and then you're left with the resulting x tensor.
             // For simultaneous sequences, all sequences need to have the same length.
-            x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
+            x = ggml_ssm_conv(ctx0, conv_x, layer.ssm_conv1d);
 
             // bias
-            x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
+            x = ggml_add(ctx0, x, layer.ssm_conv1d_b);
 
             x = ggml_silu(ctx0, x);
         }
@@ -10071,27 +10121,27 @@ struct llm_build_mamba : public llm_graph_context {
         // ssm
         {
             // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
-            ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x);
+            ggml_tensor * x_db = build_lora_mm(layer.ssm_x, x);
             // split
             ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
             ggml_tensor * B  = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
             ggml_tensor * C  = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
 
-            // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
-            if (ssm_dt_b_c_rms) {
-                dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
-                B = ggml_rms_norm(ctx0, B, norm_rms_eps);
-                C = ggml_rms_norm(ctx0, C, norm_rms_eps);
+            // Some Mamba variants (e.g. FalconMamba, Jamba) apply RMS norm in B, C & Dt layers
+            if (ssm_dt_b_c_rms || (layer.ssm_dt_norm && layer.ssm_b_norm && layer.ssm_c_norm)) {
+                dt = build_norm(dt, layer.ssm_dt_norm, NULL, LLM_NORM_RMS, il);
+                B  = build_norm(B,  layer.ssm_b_norm,  NULL, LLM_NORM_RMS, il);
+                C  = build_norm(C,  layer.ssm_c_norm,  NULL, LLM_NORM_RMS, il);
             }
 
             // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
-            dt = build_lora_mm(model.layers[il].ssm_dt, dt);
-            dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
+            dt = build_lora_mm(layer.ssm_dt, dt);
+            dt = ggml_add(ctx0, dt, layer.ssm_dt_b);
 
             cur = x;
             x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);
 
-            ggml_tensor * A = model.layers[il].ssm_a;
+            ggml_tensor * A = layer.ssm_a;
 
             // use the states and the indices provided by build_recurrent_state
             // (this is necessary in order to properly use the states before they are overwritten,
@@ -10117,16 +10167,15 @@ struct llm_build_mamba : public llm_graph_context {
 
             // TODO: skip computing output earlier for unused tokens
 
-            y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d));
-            y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+            y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
+            y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
 
             // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-            cur = build_lora_mm(model.layers[il].ssm_out, y);
+            cur = build_lora_mm(layer.ssm_out, y);
         }
 
         // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
-        // cb(cur, "mamba_out", il);
 
         return cur;
     }
@@ -10138,7 +10187,8 @@ struct llm_build_mamba : public llm_graph_context {
        const llama_model & model,
       const llama_ubatch & ubatch,
                      int   il) const {
-        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
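+        // the recurrent-state context now comes from the graph input rather than a cast of mctx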
+        const auto * mctx_cur = inp->mctx;
 
         const auto kv_head = mctx_cur->get_head();
 
@@ -10242,11 +10292,14 @@ struct llm_build_mamba : public llm_graph_context {
             // TODO: skip computing output earlier for unused tokens
 
             y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-            y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+            y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
 
             // grouped RMS norm
-            y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
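+            // ssm_norm is optional: skip the grouped norm when the tensor is absent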
+            if (model.layers[il].ssm_norm) {
+                y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
+                y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+            }
+
             y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
 
             // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
@@ -10261,6 +10314,172 @@ struct llm_build_mamba : public llm_graph_context {
     }
 };
 
+struct llm_build_mamba : public llm_graph_context_mamba {
+    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = build_inp_embd(model.tok_embd);
+
+        auto * rs_inp = build_rs_inp();
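+        // recurrent-state graph input, reused by every layer below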
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
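+            // LLM_ARCH_MAMBA and LLM_ARCH_MAMBA2 share this graph; only the layer builder differs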
+            if (model.arch == LLM_ARCH_MAMBA2) {
+                cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il);
+            } else {
+                cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il);
+            }
+
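+            // on the last layer, keep only the rows selected for output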
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        // final rmsnorm
+        cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+};
+
+struct llm_build_jamba : public llm_graph_context_mamba {
+    llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = build_inp_embd(model.tok_embd);
+
+        auto * inp_hybrid = build_inp_mem_hybrid();
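+        // hybrid memory input: exposes both the attention and the recurrent sub-inputs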
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            if (n_head_kv == 0) {
+                cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il);
+            } else {
+                // Attention
+
+                struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                // No RoPE :)
+                cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // residual
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, inpL, cur);
+            cb(cur, "ffn_inp", il);
+
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
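+            // ffn_gate_inp is only present on MoE layers; otherwise a dense SwiGLU FFN is used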
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+                // FFN
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, false,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(cur, "ffn_moe_out", il);
+            }
+
+            // residual
+            cur = ggml_add(ctx0, ffn_inp, cur);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        // final rmsnorm
+        cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_command_r : public llm_graph_context {
     llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -14706,10 +14925,8 @@ struct llm_build_ernie4_5 : public llm_graph_context {
     }
 };
 
-struct llm_build_falcon_h1 : public llm_graph_context {
-    const llama_model & model;
-
-    llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model)  {
+struct llm_build_falcon_h1 : public llm_graph_context_mamba {
+    llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         ggml_tensor * cur;
@@ -14765,7 +14982,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
             cb(Kcur, "Kcur-post-rope", il);
             cb(Vcur, "Vcur-post-rope", il);
 
-            ggml_tensor * attn_out = build_attn(inp, gf,
+            ggml_tensor * attn_out = build_attn(inp->get_attn(), gf,
                     model.layers[il].wo, NULL,
                     Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
             cb(attn_out, "attn_out", il);
@@ -14776,7 +14993,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
             // Mamba2 layer
             cb(cur, "ssm_in", il);
 
-            ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il);
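+            // shared Mamba-2 layer builder from llm_graph_context_mamba, fed with the recurrent half of the hybrid input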
+            ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il);
             cb(ssm_out, "ssm_out", il);
 
             // // Aggregation
@@ -14832,139 +15049,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
 
         ggml_build_forward_expand(gf, cur);
     }
-
-    ggml_tensor * build_mamba2_layer(
-        llm_graph_input_mem_hybrid * inp,
-                                ggml_cgraph * gf,
-                                ggml_tensor * cur,
-                         const llama_ubatch & ubatch,
-                                      int   il) const {
-        const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
-
-        const auto kv_head = kv_state->get_head();
-
-        const int64_t d_conv  = hparams.ssm_d_conv;
-        const int64_t d_inner = hparams.ssm_d_inner;
-        const int64_t d_state = hparams.ssm_d_state;
-        const int64_t n_head  = hparams.ssm_dt_rank;
-        const int64_t head_dim = d_inner / n_head;
-        const int64_t n_group = hparams.ssm_n_group;
-        const int64_t n_seqs  = ubatch.n_seqs;
-
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-
-        GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
-        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
-        ggml_tensor * conv_states_all = kv_state->get_r_l(il);
-        ggml_tensor * ssm_states_all  = kv_state->get_s_l(il);
-
-        ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
-        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-
-        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
-        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
-        // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
-
-        // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
-        ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
-        cb(zxBCdt, "zxBCdt", il);
-
-        // split the above in three
-        ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
-        ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
-        ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
-
-        // conv
-        {
-            // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
-            ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
-
-            // copy last (d_conv - 1) columns back into the state cache
-            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
-
-            ggml_build_forward_expand(gf,
-                ggml_cpy(ctx0, last_conv,
-                    ggml_view_1d(ctx0, conv_states_all,
-                        (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
-                        kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
-
-            // 1D convolution
-            // The equivalent is to make a self-overlapping view of conv_x
-            // over d_conv columns at each stride in the 3rd dimension,
-            // then element-wise multiply that with the conv1d weight,
-            // then sum the elements of each row,
-            // (the last two steps are a dot product over rows (also doable with mul_mat))
-            // then permute away the ne[0] dimension,
-            // and then you're left with the resulting x tensor.
-            // For simultaneous sequences, all sequences need to have the same length.
-            xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
-
-            // bias
-            xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
-
-            xBC = ggml_silu(ctx0, xBC);
-        }
-
-        // ssm
-        {
-            // These correspond to V K Q in SSM/attention duality
-            ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
-
-            ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
-
-            ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
-
-            // {n_head, n_seq_tokens, n_seqs}
-            dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
-
-            ggml_tensor * A = model.layers[il].ssm_a;
-
-            // use the states and the indices provided by build_rs
-            // (this is necessary in order to properly use the states before they are overwritten,
-            //  while avoiding to make unnecessary copies of the states)
-            auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
-                ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
-
-                // TODO: use semistructured matrices to implement state-space duality
-                // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-                return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
-            };
-
-            ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
-
-            // store last states
-            ggml_build_forward_expand(gf,
-                ggml_cpy(ctx0,
-                    ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
-                    ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
-
-            ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
-
-            // TODO: skip computing output earlier for unused tokens
-
-            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-            y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
-
-            // grouped RMS norm
-            if (model.layers[il].ssm_norm) {
-                y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-                y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
-            }
-
-            y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
-
-            // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-            cur = build_lora_mm(model.layers[il].ssm_out, y);
-        }
-
-        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
-        cb(cur, "mamba_out", il);
-        return cur;
-    }
 };
 
 struct llm_build_arcee : public llm_graph_context {
@@ -15641,6 +15725,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_mamba>(*this, params, gf);
             } break;
+        case LLM_ARCH_JAMBA:
+            {
+                llm = std::make_unique<llm_build_jamba>(*this, params, gf);
+            } break;
         case LLM_ARCH_XVERSE:
             {
                 llm = std::make_unique<llm_build_xverse>(*this, params, gf);
@@ -15911,6 +15999,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_BLOOM:
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_MAMBA2:
+        case LLM_ARCH_JAMBA:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_T5:
         case LLM_ARCH_T5ENCODER:
index 70a6dc89e1b06ea24a3c253bdb716f22123e313a..453f5af62fbc7e01f18e368a2f894e2d4eb9234f 100644 (file)
@@ -174,6 +174,9 @@ struct llama_layer {
     struct ggml_tensor * attn_norm_cross = nullptr;
     struct ggml_tensor * attn_norm_enc   = nullptr;
     struct ggml_tensor * ssm_norm        = nullptr;
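+    // optional RMS norms applied to the SSM dt, B and C projections (used by Jamba)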
+    struct ggml_tensor * ssm_dt_norm     = nullptr;
+    struct ggml_tensor * ssm_b_norm      = nullptr;
+    struct ggml_tensor * ssm_c_norm      = nullptr;
 
     // attention
     struct ggml_tensor * wq        = nullptr;