]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : gemma3n text-only (#14400)
authorXuan-Son Nguyen <redacted>
Thu, 26 Jun 2025 17:34:02 +0000 (19:34 +0200)
committerGitHub <redacted>
Thu, 26 Jun 2025 17:34:02 +0000 (20:34 +0300)
* gemma3n

* add llm_graph_input_one

13 files changed:
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-graph.cpp
src/llama-graph.h
src/llama-hparams.h
src/llama-kv-cache-unified.cpp
src/llama-model.cpp
src/llama-model.h
src/llama-quant.cpp

index bbf8b30ff5324f0bf8fff3b613ad11def1dc9aad..4f2339a02a13c4b0d28b2631adf07d4506a4ee13 100755 (executable)
@@ -310,6 +310,8 @@ class ModelBase:
                             gguf.MODEL_TENSOR.POSNET_NORM2,
                             gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
                             gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
+                            gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
+                            gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
                         )
                     )
                     or not new_name.endswith(".weight")
@@ -320,7 +322,11 @@ class ModelBase:
                     self.match_model_tensor_name(new_name, key, bid)
                     for key in (
                         gguf.MODEL_TENSOR.TOKEN_EMBD,
+                        gguf.MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
                         gguf.MODEL_TENSOR.OUTPUT,
+                        gguf.MODEL_TENSOR.ALTUP_ROUTER,
+                        gguf.MODEL_TENSOR.LAUREL_L,
+                        gguf.MODEL_TENSOR.LAUREL_R,
                     )
                 ):
                     if self.ftype in (
@@ -921,13 +927,16 @@ class TextModel(ModelBase):
         tokenizer = SentencePieceProcessor()
         tokenizer.LoadFromFile(str(tokenizer_path))
 
-        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+        vocab_size = self.find_hparam([
+            "vocab_size_per_layer_input", # gemma3n
+            "vocab_size",
+        ], optional=True) or tokenizer.vocab_size()
 
         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
         toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
 
-        for token_id in range(tokenizer.vocab_size()):
+        for token_id in range(vocab_size):
             piece = tokenizer.IdToPiece(token_id)
             text = piece.encode("utf-8")
             score = tokenizer.GetScore(token_id)
@@ -942,6 +951,10 @@ class TextModel(ModelBase):
             elif tokenizer.IsByte(token_id):
                 toktype = SentencePieceTokenTypes.BYTE
 
+            if token_id >= vocab_size:
+                logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}')
+                break
+
             tokens[token_id] = text
             scores[token_id] = score
             toktypes[token_id] = toktype
@@ -4217,6 +4230,7 @@ class Gemma2Model(TextModel):
 @ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
 class Gemma3Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GEMMA3
+    norm_shift = 1.0  # Gemma3RMSNorm adds 1.0 to the norm value
 
     def set_vocab(self):
         self._set_vocab_sentencepiece()
@@ -4238,9 +4252,8 @@ class Gemma3Model(TextModel):
         self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
-        # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+        # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        assert hparams.get("final_logit_softcapping") is None
         self.gguf_writer.add_sliding_window(hparams["sliding_window"])
         self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
         if hparams.get("rope_scaling") is not None:
@@ -4252,7 +4265,7 @@ class Gemma3Model(TextModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
-        if name.startswith("language_model."):
+        if "language_model." in name:
             name = name.replace("language_model.", "")
 
         elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
@@ -4267,8 +4280,9 @@ class Gemma3Model(TextModel):
 
         # ref code in Gemma3RMSNorm
         # output = output * (1.0 + self.weight.float())
+        # note: this is not the case on gemma3n
         if name.endswith("norm.weight"):
-            data_torch = data_torch + 1
+            data_torch = data_torch + self.norm_shift
 
         return [(self.map_tensor_name(name), data_torch)]
 
@@ -4325,6 +4339,104 @@ class Gemma3VisionModel(MmprojModel):
         return [] # skip other tensors
 
 
+@ModelBase.register("Gemma3nForConditionalGeneration")
+class Gemma3NModel(Gemma3Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA3N
+    norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code
+
+    _altup_proj: list[Tensor] = []
+    _altup_unembd: list[Tensor] = []
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs"
+        self._altup_proj = [
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+        ]
+        self._altup_unembd = [
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+        ]
+
+    def set_vocab(self):
+        with open(self.dir_model / "chat_template.jinja") as f:
+            # quick hack to make sure chat template is added
+            self.gguf_writer.add_chat_template(f.read())
+        super().set_vocab()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
+        self.gguf_writer.add_altup_num_inputs(self.hparams["altup_num_inputs"])
+        self.gguf_writer.add_embedding_length_per_layer_input(self.hparams["hidden_size_per_layer_input"])
+        self.gguf_writer.add_shared_kv_layers(self.hparams["num_kv_shared_layers"])
+
+        activation_sparsity_scale = []
+        for s in self.hparams["activation_sparsity_pattern"]:
+            normal_dist = torch.distributions.normal.Normal(0, 1)
+            std_multiplier = normal_dist.icdf(torch.tensor(s, dtype=torch.float32))
+            activation_sparsity_scale.append(std_multiplier.item())
+        self.gguf_writer.add_activation_sparsity_scale(activation_sparsity_scale)
+
+        sliding_window_pattern = []
+        for t in self.hparams["layer_types"]:
+            sliding_window_pattern.append(t == "sliding_attention")
+        self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
+
+    def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None:
+        has_all = all(m.numel() > 0 for m in matrices)
+        if not has_all:
+            return None
+        else:
+            return torch.stack(matrices, dim=0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith("_scale"):
+            name = name + ".weight"
+
+        # TODO: implement self.prediction_coefs.weight.clamp_(...)
+
+        if "language_model." not in name:
+            return [] # skip non-language model tensors
+
+        if "altup_unembed_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_unembd[0] = data_torch
+            elif ".1." in name:
+                self._altup_unembd[1] = data_torch
+            elif ".2." in name:
+                self._altup_unembd[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_unembd)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)]
+            else:
+                return []
+
+        if "altup_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_proj[0] = data_torch
+            elif ".1." in name:
+                self._altup_proj[1] = data_torch
+            elif ".2." in name:
+                self._altup_proj[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_proj)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_projections.weight"), out)]
+            else:
+                return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Starcoder2ForCausalLM")
 class StarCoder2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.STARCODER2
index 0429b0aaf135dddf0f13f5a9e43f9889e1f7c641..fb75143b0b54586efb100af1add9568986a5dd49 100644 (file)
@@ -118,6 +118,10 @@ class Keys:
         EMBEDDING_SCALE                   = "{arch}.embedding_scale"
         TOKEN_SHIFT_COUNT                 = "{arch}.token_shift_count"
         INTERLEAVE_MOE_LAYER_STEP         = "{arch}.interleave_moe_layer_step"
+        ACTIVATION_SPARSITY_SCALE         = "{arch}.activation_sparsity_scale"
+        ALTUP_ACTIVE_IDX                  = "{arch}.altup.active_idx"
+        ALTUP_NUM_INPUTS                  = "{arch}.altup.num_inputs"
+        EMBD_LENGTH_PER_LAYER_INP         = "{arch}.embedding_length_per_layer_input"
 
     class Attention:
         HEAD_COUNT                   = "{arch}.attention.head_count"
@@ -142,6 +146,8 @@ class Keys:
         SCALE                        = "{arch}.attention.scale"
         KEY_LENGTH_MLA               = "{arch}.attention.key_length_mla"
         VALUE_LENGTH_MLA             = "{arch}.attention.value_length_mla"
+        SHARED_KV_LAYERS             = "{arch}.attention.shared_kv_layers"
+        SLIDING_WINDOW_PATTERN       = "{arch}.attention.sliding_window_pattern"
 
     class Rope:
         DIMENSION_COUNT         = "{arch}.rope.dimension_count"
@@ -314,6 +320,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA            = auto()
     GEMMA2           = auto()
     GEMMA3           = auto()
+    GEMMA3N          = auto()
     STARCODER2       = auto()
     RWKV6            = auto()
     RWKV6QWEN2       = auto()
@@ -399,6 +406,22 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_NORM          = auto()
     ATTN_K_NORM          = auto()
     LAYER_OUT_NORM       = auto()
+    PER_LAYER_TOKEN_EMBD = auto() # gemma3n
+    PER_LAYER_MODEL_PROJ = auto() # gemma3n
+    PER_LAYER_INP_GATE   = auto() # gemma3n
+    PER_LAYER_PROJ       = auto() # gemma3n
+    PER_LAYER_PROJ_NORM  = auto() # gemma3n
+    PER_LAYER_POST_NORM  = auto() # gemma3n
+    ALTUP_PROJ           = auto() # gemma3n
+    ALTUP_UNEMBD_PROJ    = auto() # gemma3n
+    ALTUP_CORRECT_COEF   = auto() # gemma3n
+    ALTUP_CORRECT_SCALE  = auto() # gemma3n
+    ALTUP_PREDICT_COEF   = auto() # gemma3n
+    ALTUP_ROUTER         = auto() # gemma3n
+    ALTUP_ROUTER_NORM    = auto() # gemma3n
+    LAUREL_L             = auto() # gemma3n
+    LAUREL_R             = auto() # gemma3n
+    LAUREL_POST_NORM     = auto() # gemma3n
     SSM_IN               = auto()
     SSM_CONV1D           = auto()
     SSM_X                = auto()
@@ -597,6 +620,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.GEMMA:            "gemma",
     MODEL_ARCH.GEMMA2:           "gemma2",
     MODEL_ARCH.GEMMA3:           "gemma3",
+    MODEL_ARCH.GEMMA3N:          "gemma3n",
     MODEL_ARCH.STARCODER2:       "starcoder2",
     MODEL_ARCH.RWKV6:            "rwkv6",
     MODEL_ARCH.RWKV6QWEN2:       "rwkv6qwen2",
@@ -682,6 +706,22 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.FFN_UP_EXP:                "blk.{bid}.ffn_up_exps",
     MODEL_TENSOR.FFN_EXP_PROBS_B:           "blk.{bid}.exp_probs_b",
     MODEL_TENSOR.LAYER_OUT_NORM:            "blk.{bid}.layer_output_norm",
+    MODEL_TENSOR.PER_LAYER_TOKEN_EMBD:      "per_layer_token_embd",           # gemma3n
+    MODEL_TENSOR.PER_LAYER_MODEL_PROJ:      "per_layer_model_proj",           # gemma3n
+    MODEL_TENSOR.PER_LAYER_PROJ_NORM:       "per_layer_proj_norm",            # gemma3n
+    MODEL_TENSOR.ALTUP_UNEMBD_PROJ:         "altup_unembd_proj",              # gemma3n
+    MODEL_TENSOR.ALTUP_PROJ:                "altup_proj",                     # gemma3n
+    MODEL_TENSOR.PER_LAYER_INP_GATE:        "blk.{bid}.inp_gate",             # gemma3n
+    MODEL_TENSOR.PER_LAYER_PROJ:            "blk.{bid}.proj",                 # gemma3n
+    MODEL_TENSOR.PER_LAYER_POST_NORM:       "blk.{bid}.post_norm",            # gemma3n
+    MODEL_TENSOR.ALTUP_CORRECT_COEF:        "blk.{bid}.altup_correct_coef",   # gemma3n
+    MODEL_TENSOR.ALTUP_CORRECT_SCALE:       "blk.{bid}.altup_correct_scale",  # gemma3n
+    MODEL_TENSOR.ALTUP_PREDICT_COEF:        "blk.{bid}.altup_predict_coef",   # gemma3n
+    MODEL_TENSOR.ALTUP_ROUTER:              "blk.{bid}.altup_router",         # gemma3n
+    MODEL_TENSOR.ALTUP_ROUTER_NORM:         "blk.{bid}.altup_router_norm",    # gemma3n
+    MODEL_TENSOR.LAUREL_L:                  "blk.{bid}.laurel_l",             # gemma3n
+    MODEL_TENSOR.LAUREL_R:                  "blk.{bid}.laurel_r",             # gemma3n
+    MODEL_TENSOR.LAUREL_POST_NORM:          "blk.{bid}.laurel_post_norm",     # gemma3n
     MODEL_TENSOR.SSM_IN:                    "blk.{bid}.ssm_in",
     MODEL_TENSOR.SSM_CONV1D:                "blk.{bid}.ssm_conv1d",
     MODEL_TENSOR.SSM_X:                     "blk.{bid}.ssm_x",
@@ -1486,6 +1526,41 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_PRE_NORM,
         MODEL_TENSOR.FFN_POST_NORM,
     ],
+    MODEL_ARCH.GEMMA3N: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_PRE_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+        # altup / laurel
+        MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
+        MODEL_TENSOR.PER_LAYER_MODEL_PROJ,
+        MODEL_TENSOR.PER_LAYER_INP_GATE,
+        MODEL_TENSOR.PER_LAYER_PROJ,
+        MODEL_TENSOR.PER_LAYER_PROJ_NORM,
+        MODEL_TENSOR.PER_LAYER_POST_NORM,
+        MODEL_TENSOR.ALTUP_PROJ,
+        MODEL_TENSOR.ALTUP_UNEMBD_PROJ,
+        MODEL_TENSOR.ALTUP_CORRECT_COEF,
+        MODEL_TENSOR.ALTUP_CORRECT_SCALE,
+        MODEL_TENSOR.ALTUP_PREDICT_COEF,
+        MODEL_TENSOR.ALTUP_ROUTER,
+        MODEL_TENSOR.ALTUP_ROUTER_NORM,
+        MODEL_TENSOR.LAUREL_L,
+        MODEL_TENSOR.LAUREL_R,
+        MODEL_TENSOR.LAUREL_POST_NORM,
+    ],
     MODEL_ARCH.STARCODER2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index b9b63d052624d81c53852381a86e26129efb5261..d32cd479adb17aa5f7cac4d5b75ba9e7f1a3deac 100644 (file)
@@ -672,6 +672,18 @@ class GGUFWriter:
     def add_decoder_start_token_id(self, id: int) -> None:
         self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
 
+    def add_embedding_length_per_layer_input(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)
+
+    def add_altup_active_idx(self, val: int) -> None:
+        self.add_uint32(Keys.LLM.ALTUP_ACTIVE_IDX.format(arch=self.arch), val)
+
+    def add_altup_num_inputs(self, val: int) -> None:
+        self.add_uint32(Keys.LLM.ALTUP_NUM_INPUTS.format(arch=self.arch), val)
+
+    def add_activation_sparsity_scale(self, values: Sequence[float]) -> None:
+        self.add_array(Keys.LLM.ACTIVATION_SPARSITY_SCALE.format(arch=self.arch), values)
+
     def add_head_count(self, count: int | Sequence[int]) -> None:
         if isinstance(count, int):
             self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
@@ -702,6 +714,12 @@ class GGUFWriter:
     def add_clamp_kqv(self, value: float) -> None:
         self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
 
+    def add_shared_kv_layers(self, value: float) -> None:
+        self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
+
+    def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
+        self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
+
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
 
index 79f044d2a5945236b613ac3733ccc2b60ebd9ba4..b30f77dbe3be74c75a7e2049d97f587543216235 100644 (file)
@@ -480,6 +480,70 @@ class TensorNameMap:
             "encoder.layer.{bid}.layer_norm_2"              # jina-v2-code
         ),
 
+        MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
+            "model.embed_tokens_per_layer",  # gemma3n
+        ),
+
+        MODEL_TENSOR.PER_LAYER_MODEL_PROJ: (
+            "model.per_layer_model_projection",  # gemma3n
+        ),
+
+        MODEL_TENSOR.PER_LAYER_PROJ_NORM: (
+            "model.per_layer_projection_norm",  # gemma3n
+        ),
+
+        MODEL_TENSOR.ALTUP_PROJ: (
+            "model.altup_projections",  # gemma3n
+        ),
+
+        MODEL_TENSOR.ALTUP_UNEMBD_PROJ: (
+            "model.altup_unembed_projections",  # gemma3n
+        ),
+
+        MODEL_TENSOR.PER_LAYER_INP_GATE: (
+            "model.layers.{bid}.per_layer_input_gate",  # gemma3n
+        ),
+
+        MODEL_TENSOR.PER_LAYER_PROJ: (
+            "model.layers.{bid}.per_layer_projection",  # gemma3n
+        ),
+
+        MODEL_TENSOR.PER_LAYER_POST_NORM: (
+            "model.layers.{bid}.post_per_layer_input_norm",  # gemma3n
+        ),
+
+        MODEL_TENSOR.ALTUP_CORRECT_COEF: (
+            "model.layers.{bid}.altup.correction_coefs",  # gemma3n
+        ),
+
+        MODEL_TENSOR.ALTUP_CORRECT_SCALE: (
+            "model.layers.{bid}.altup.correct_output_scale",  # gemma3n
+        ),
+
+        MODEL_TENSOR.ALTUP_PREDICT_COEF: (
+            "model.layers.{bid}.altup.prediction_coefs",  # gemma3n
+        ),
+
+        MODEL_TENSOR.ALTUP_ROUTER: (
+            "model.layers.{bid}.altup.modality_router",  # gemma3n
+        ),
+
+        MODEL_TENSOR.ALTUP_ROUTER_NORM: (
+            "model.layers.{bid}.altup.router_norm",  # gemma3n
+        ),
+
+        MODEL_TENSOR.LAUREL_L: (
+            "model.layers.{bid}.laurel.linear_left",  # gemma3n
+        ),
+
+        MODEL_TENSOR.LAUREL_R: (
+            "model.layers.{bid}.laurel.linear_right",  # gemma3n
+        ),
+
+        MODEL_TENSOR.LAUREL_POST_NORM: (
+            "model.layers.{bid}.laurel.post_laurel_norm",  # gemma3n
+        ),
+
         MODEL_TENSOR.SSM_IN: (
             "model.layers.{bid}.in_proj",
             "backbone.layers.{bid}.mixer.in_proj",
index 8dadef204f9d71f45039c4402aedcb5e923683ad..435e3b9ba3db8b50521cea4a384b137b014d3e7c 100644 (file)
@@ -42,6 +42,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA,            "gemma"            },
     { LLM_ARCH_GEMMA2,           "gemma2"           },
     { LLM_ARCH_GEMMA3,           "gemma3"           },
+    { LLM_ARCH_GEMMA3N,          "gemma3n"          },
     { LLM_ARCH_STARCODER2,       "starcoder2"       },
     { LLM_ARCH_MAMBA,            "mamba"            },
     { LLM_ARCH_XVERSE,           "xverse"           },
@@ -932,6 +933,42 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA3N,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,          "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,            "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,               "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,          "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,               "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,          "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,               "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,             "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM,       "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM,             "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,             "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,             "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,               "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM,        "blk.%d.post_ffw_norm" },
+            { LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
+            { LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
+            { LLM_TENSOR_PER_LAYER_PROJ_NORM,  "per_layer_proj_norm" },
+            { LLM_TENSOR_ALTUP_UNEMBD_PROJ,    "altup_unembd_proj" },
+            { LLM_TENSOR_ALTUP_PROJ,           "altup_proj" },
+            { LLM_TENSOR_PER_LAYER_INP_GATE,   "blk.%d.inp_gate" },
+            { LLM_TENSOR_PER_LAYER_PROJ,       "blk.%d.proj" },
+            { LLM_TENSOR_PER_LAYER_POST_NORM,  "blk.%d.post_norm" },
+            { LLM_TENSOR_ALTUP_CORRECT_COEF,   "blk.%d.altup_correct_coef" },
+            { LLM_TENSOR_ALTUP_CORRECT_SCALE,  "blk.%d.altup_correct_scale" },
+            { LLM_TENSOR_ALTUP_PREDICT_COEF,   "blk.%d.altup_predict_coef" },
+            { LLM_TENSOR_ALTUP_ROUTER,         "blk.%d.altup_router" },
+            { LLM_TENSOR_ALTUP_ROUTER_NORM,    "blk.%d.altup_router_norm" },
+            { LLM_TENSOR_LAUREL_L,             "blk.%d.laurel_l" },
+            { LLM_TENSOR_LAUREL_R,             "blk.%d.laurel_r" },
+            { LLM_TENSOR_LAUREL_POST_NORM,     "blk.%d.laurel_post_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
@@ -1749,6 +1786,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_GATE_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_EXP_PROBS_B,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    // altup / laurel (gemma 3n)
+    {LLM_TENSOR_PER_LAYER_TOKEN_EMBD,       {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_PER_LAYER_MODEL_PROJ,       {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_PROJ_NORM,        {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_PROJ,                 {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_UNEMBD_PROJ,          {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_INP_GATE,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_PROJ,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_POST_NORM,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_CORRECT_COEF,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_CORRECT_SCALE,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_PREDICT_COEF,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_ROUTER,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_ROUTER_NORM,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_LAUREL_L,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LAUREL_R,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LAUREL_POST_NORM,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_CONV1D,                     {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
index 5b0230c15067817a4500c903259f7e97a1213db3..9181ad053f6b3e61307a096aace4990b870f1699 100644 (file)
@@ -46,6 +46,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
     LLM_ARCH_GEMMA3,
+    LLM_ARCH_GEMMA3N,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
@@ -269,6 +270,22 @@ enum llm_tensor {
     LLM_TENSOR_LAYER_OUT_NORM,
     LLM_TENSOR_POST_ATTN_NORM,
     LLM_TENSOR_POST_MLP_NORM,
+    LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
+    LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
+    LLM_TENSOR_PER_LAYER_INP_GATE,   // gemma3n
+    LLM_TENSOR_PER_LAYER_PROJ,       // gemma3n
+    LLM_TENSOR_PER_LAYER_PROJ_NORM,  // gemma3n
+    LLM_TENSOR_PER_LAYER_POST_NORM,  // gemma3n
+    LLM_TENSOR_ALTUP_PROJ,           // gemma3n
+    LLM_TENSOR_ALTUP_UNEMBD_PROJ,    // gemma3n
+    LLM_TENSOR_ALTUP_CORRECT_COEF,   // gemma3n
+    LLM_TENSOR_ALTUP_CORRECT_SCALE,  // gemma3n
+    LLM_TENSOR_ALTUP_PREDICT_COEF,   // gemma3n
+    LLM_TENSOR_ALTUP_ROUTER,         // gemma3n
+    LLM_TENSOR_ALTUP_ROUTER_NORM,    // gemma3n
+    LLM_TENSOR_LAUREL_L,             // gemma3n
+    LLM_TENSOR_LAUREL_R,             // gemma3n
+    LLM_TENSOR_LAUREL_POST_NORM,     // gemma3n
     LLM_TENSOR_SSM_IN,
     LLM_TENSOR_SSM_CONV1D,
     LLM_TENSOR_SSM_X,
index 48589a50ab24d4f535b6c76add9061ad393c047d..71ee431a977ba5dba2a55791da9bba152f87f2a7 100644 (file)
@@ -350,6 +350,12 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_one::set_input(const llama_ubatch *) {
+    GGML_ASSERT(one && ggml_nelements(one) == 1);
+    float f_one = 1.0f;
+    ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
+}
+
 //
 // llm_graph_context
 //
@@ -1267,8 +1273,14 @@ ggml_tensor * llm_graph_context::build_attn(
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
+
+    if (k_cur) {
+        ggml_build_forward_expand(gf, k_cur);
+    }
+
+    if (v_cur) {
+        ggml_build_forward_expand(gf, v_cur);
+    }
 
     const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
@@ -1276,9 +1288,12 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 
-    // store to KV cache
-    {
+    // optionally store to KV cache
+    if (k_cur) {
         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+    }
+
+    if (v_cur) {
         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
index b433f266d1b295e94899780b0f87cf57a52781d1..4b1ec354dfc30f06d00b4dcb1d73857e8ca10cb8 100644 (file)
@@ -329,6 +329,17 @@ public:
     const llama_memory_hybrid_context * mctx;
 };
 
+// TODO: remove this when ggml_scale_add is implemented
+class llm_graph_input_one : public llm_graph_input_i {
+public:
+    llm_graph_input_one() {}
+    virtual ~llm_graph_input_one() = default;
+
+    void set_input(const llama_ubatch *) override;
+
+    ggml_tensor * one = nullptr; // F32
+};
+
 //
 // llm_graph_result
 //
@@ -589,14 +600,15 @@ struct llm_graph_context {
 
     llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
+    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
index 7b315a9a74b1da6669ba64eb2c6efe3ab88c8d81..e85afe145a922a57d74b887d343857c35cff5ced 100644 (file)
@@ -143,6 +143,12 @@ struct llama_hparams {
     uint32_t n_attn_temp_floor_scale = 8192;
     float    f_attn_temp_scale       = 0.1;
 
+    // gemma3n altup
+    uint32_t n_altup      = 4; // altup_num_inputs
+    uint32_t i_altup_act  = 0; // altup_active_idx
+    uint32_t laurel_rank  = 64;
+    uint32_t n_embd_altup = 256;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
index b506d32ed4d06af3096656cb6b3368dad3f800da..8517b722a9f80715ce011c7b8b158fbbaeed1ae5 100644 (file)
@@ -33,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
+    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
+    auto n_layer_cache = hparams.n_layer;
+    if (model.arch == LLM_ARCH_GEMMA3N) {
+        n_layer_cache = 20;
+    }
+
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -62,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
-    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+    for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
             continue;
@@ -102,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         layers.push_back({ il, k, v });
     }
 
+    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
+    if (model.arch == LLM_ARCH_GEMMA3N) {
+        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
+
+        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
+            if (filter && !filter(il)) {
+                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+                continue;
+            }
+
+            const bool     is_swa   = hparams.is_swa(il);
+            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
+
+            GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
+            map_layer_ids[il] = map_layer_ids[il_reuse];
+
+            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
+        }
+    }
+
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
     for (auto it : ctx_map) {
         auto * buft = it.first;
index c2835ce67a88d6fc5a2577495a0a4f20714ddf1e..fc39195ed5177a14e7d7a67fedaa1f0c110e712e 100644 (file)
@@ -103,6 +103,8 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_17B_128E:      return "17Bx128E (Maverick)";
         case LLM_TYPE_30B_A3B:       return "30B.A3B";
         case LLM_TYPE_235B_A22B:     return "235B.A22B";
+        case LLM_TYPE_E2B:           return "E2B";
+        case LLM_TYPE_E4B:           return "E4B";
         default:                     return "?B";
     }
 }
@@ -1017,6 +1019,24 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
                     : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
             } break;
+        case LLM_ARCH_GEMMA3N:
+            {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                hparams.set_swa_pattern(5);
+
+                hparams.rope_freq_base_train_swa  = 10000.0f;
+                hparams.rope_freq_scale_train_swa = 1.0f;
+                hparams.f_attention_scale         = 1.0f;
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 30: type = LLM_TYPE_E2B; break;
+                    case 35: type = LLM_TYPE_E4B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2950,6 +2970,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_GEMMA3N:
+                {
+                    const int64_t n_altup      = hparams.n_altup;
+                    const int64_t laurel_rank  = hparams.laurel_rank;
+                    const int64_t n_embd_altup = hparams.n_embd_altup;
+
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    tok_embd           = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,           "weight"), {n_embd, n_vocab}, 0);
+                    tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
+
+                    altup_proj           = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ,           "weight"), {n_embd, n_embd, n_altup - 1}, 0);
+                    altup_unembd_proj    = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ,    "weight"), {n_embd, n_embd, n_altup - 1}, 0);
+                    per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
+                    per_layer_proj_norm  = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM,  "weight"), {n_embd_altup}, 0);
+
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.attn_q_norm    = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM,    "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_k_norm    = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM,    "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        // altup & laurel
+                        layer.per_layer_inp_gate   = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE,  "weight", i), {n_embd, n_embd_altup}, 0);
+                        layer.per_layer_proj       = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ,      "weight", i), {n_embd_altup, n_embd}, 0);
+                        layer.per_layer_post_norm  = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
+                        layer.altup_correct_coef   = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF,  "weight", i), {n_altup, n_altup}, 0);
+                        layer.altup_correct_scale  = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0);
+                        layer.altup_predict_coef   = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF,  "weight", i), {n_altup, n_altup * n_altup}, 0);
+                        layer.altup_router         = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER,        "weight", i), {n_embd, n_altup}, 0);
+                        layer.altup_router_norm    = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM,   "weight", i), {n_embd}, 0);
+                        layer.laurel_l             = create_tensor(tn(LLM_TENSOR_LAUREL_L,            "weight", i), {n_embd, laurel_rank}, 0);
+                        layer.laurel_r             = create_tensor(tn(LLM_TENSOR_LAUREL_R,            "weight", i), {laurel_rank, n_embd}, 0);
+                        layer.laurel_post_norm     = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM,    "weight", i), {n_embd}, 0);
+                    }
+                } break;
             case LLM_ARCH_STARCODER2:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -8980,6 +9056,442 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
     }
 };
 
+struct llm_build_gemma3n_iswa : public llm_graph_context {
+    const llama_model & model;
+    ggml_cgraph * gf;
+
+    const int64_t n_embd_head;
+    const int64_t n_embd_altup;
+    const int64_t n_altup;
+    const int     i_altup_act;
+    const int     n_layer_kv = 20; // number of layers having KV [KV_REUSE]
+    const int     n_layer_sparsity = 10; // number of layers using activation sparsity
+    const float   f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
+
+    ggml_tensor * one; // containing single element 1.0f
+
+    llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
+            : llm_graph_context(params),
+              model(model),
+              gf(gf),
+              n_embd_head(model.hparams.n_embd_head_k),
+              n_embd_altup(model.hparams.n_embd_altup),
+              n_altup(model.hparams.n_altup),
+              i_altup_act(model.hparams.i_altup_act) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // TODO: remove this when ggml_scale_add is implemented
+        one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        {
+            auto inp = std::make_unique<llm_graph_input_one>();
+            inp->one = one;
+            res->add_input(std::move(inp));
+        }
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
+        if (ubatch.token) {
+            inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+            cb(inpL, "inp_scaled", -1);
+        }
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        // TODO: is causal == true correct? might need some changes
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+
+        // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
+        ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
+
+        // inpL now has only 1 altup, project it to the rest of the altups
+        // these "added" altups will be concat to the last dim of inpL
+        {
+            ggml_tensor * target_magnitude = calc_magnitude(inpL);
+            ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
+            ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
+            ggml_tensor * new_magnitude = calc_magnitude(altup_added);
+            altup_added = ggml_div(ctx0,
+                                ggml_mul(ctx0, altup_added, target_magnitude),
+                                new_magnitude);
+            inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
+            cb(inpL, "inp_stacked", -1);
+        }
+
+        // inpL now has shape:          [n_embd,       n_tokens, n_altup]
+        // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
+
+        for (int il = 0; il < n_layer; ++il) {
+            // this block is made to be closely resemble Gemma3p5DecoderLayer on python code
+            const bool has_kv = (il < n_layer_kv);
+
+            const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+            const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+            ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
+            ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
+
+            // predicted value will go through self-attention and laurel
+            ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
+            cur = active_prediction;
+            cb(cur, "active_prediction", il);
+
+            // norm
+            cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // laurel
+            ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
+
+            // self-attention
+            if (has_kv) {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
+
+                cb(Qcur, "Qcur_normed", il);
+                cb(Kcur, "Kcur_normed", il);
+                cb(Vcur, "Vcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                cb(Qcur, "Qcur_pos", il);
+                cb(Kcur, "Kcur_pos", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
+            } else {
+                // no KV layers
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur_pos", il);
+
+                cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, NULL,
+                    Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
+            cb(cur, "attn_gated", il);
+
+            ggml_tensor * attn_laurel = ggml_scale(ctx0,
+                                            ggml_add(ctx0, cur, laurel_out),
+                                            1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
+            cb(attn_laurel, "attn_laurel", il);
+
+            cur = build_norm(attn_laurel,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                ggml_tensor * up_proj   = build_lora_mm(model.layers[il].ffn_up,   cur);
+                ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
+
+                if (il < n_layer_sparsity) {
+                    // apply activation sparsity
+                    gate_proj = gaussian_topk(gate_proj);
+                }
+                gate_proj = ggml_gelu(ctx0, gate_proj);
+
+                cur = ggml_mul(ctx0, up_proj, gate_proj);
+                cur = build_lora_mm(model.layers[il].ffn_down, cur);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", il);
+
+            ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
+            cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
+
+            ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]
+
+            ggml_tensor * first_prediction; // [n_embd, n_tokens]
+            {
+                first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
+                first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
+                first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
+                first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
+                cb(first_prediction, "first_prediction_gated", il);
+                ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
+                first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
+                cb(first_prediction, "first_prediction_scaled", il);
+
+                first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
+                first_prediction = build_norm(first_prediction,
+                        model.layers[il].per_layer_post_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(first_prediction, "first_prediction_out", il);
+            }
+
+            // equivalent to python code: corrected_predictions[1:] += first_prediction
+            {
+                ggml_tensor * slice_first = view_2d_slice(corrected, 0);
+                ggml_tensor * slice_rest  = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
+                                                    ggml_row_size(corrected->type, n_embd),
+                                                    ggml_row_size(corrected->type, n_embd*n_tokens),
+                                                    n_embd*n_tokens*ggml_element_size(corrected));
+                ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
+                corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
+            }
+
+            cur = corrected; // [n_embd, n_tokens, n_altup]
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL; // [n_embd, n_tokens, n_altup]
+
+        // cur now has multiple altup(s), we want to merge them back to 1 altup
+        {
+            ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
+            // do a view to skip the first slice (active altup)
+            ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1,
+                                                    ggml_row_size(cur->type, n_embd),
+                                                    ggml_row_size(cur->type, n_embd*n_tokens),
+                                                    n_embd*n_tokens*ggml_element_size(cur));
+            ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
+            ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
+            altup_unembd = ggml_div(ctx0,
+                                ggml_mul(ctx0, altup_unembd, target_magnitude),
+                                new_magnitude);
+            cb(altup_unembd, "altup_unembd", -1);
+
+            // equivalent to torch.mean(hidden_states, dim=0)
+            cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
+            for (int i = 0; i < n_altup - 1; ++i) {
+                cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
+            }
+            cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
+            cb(cur, "unembd_merged", -1);
+        }
+
+        // cur now has shape: [n_embd, n_tokens]
+
+        // TODO: move this to right after the last KV layer
+        {
+            // skip computing output for unused tokens
+            ggml_tensor * inp_out_ids = build_inp_out_ids();
+            cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+        }
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        {
+            // final logit soft-capping
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+            cur = ggml_tanh(ctx0, cur);
+            cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+        }
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    ggml_tensor * calc_magnitude(ggml_tensor * x) {
+        return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
+    }
+
+    // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
+    ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) {
+        GGML_ASSERT(idx < (int)x->ne[2]);
+        return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1],
+                            ggml_row_size(x->type, x->ne[0]),
+                            idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
+    }
+
+    // equivalent to get_per_layer_inputs() in python code
+    // output shape: [n_embd_altup, n_layer, n_tokens]
+    ggml_tensor * get_per_layer_inputs() {
+        auto inp = std::make_unique<llm_graph_input_embd>();
+        ggml_tensor * inp_per_layer;
+        if (ubatch.token) {
+            inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
+            ggml_set_input(inp->tokens);
+            res->t_tokens = inp->tokens;
+            inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
+            inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
+            inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float)n_embd_altup));
+            cb(inp_per_layer, "inp_per_layer_selected", -1);
+        } else {
+            GGML_ABORT("TODO: support embd input");
+        }
+        res->add_input(std::move(inp));
+        return inp_per_layer;
+    }
+
+    // equivalent to project_per_layer_inputs() in python code
+    // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
+    // output shape: [n_embd_altup, n_tokens, n_layer]
+    ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
+        const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd);
+        const float per_layer_input_scale      = 1.0f / sqrtf(2.0f);
+
+        ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
+        per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
+        per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
+        per_layer_proj = build_norm(per_layer_proj,
+                                    model.per_layer_proj_norm, NULL,
+                                    LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens]
+        cb(per_layer_proj, "per_layer_proj", -1);
+
+        inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
+        inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
+        cb(inp_per_layer, "inp_per_layer", -1);
+
+        // permute to shape: [n_embd_altup, n_tokens, n_layer]
+        inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
+        return inp_per_layer;
+    }
+
+    // input cur shape: [n_altup, n_tokens]
+    // output    shape: [n_altup, n_tokens]
+    ggml_tensor * laurel(ggml_tensor * cur, int il) {
+        ggml_tensor * tmp = cur;
+        tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
+        tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
+        tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
+        tmp = ggml_add(ctx0, tmp, cur);
+        cb(tmp, "laurel_out", il);
+        return tmp;
+    }
+
+    // input x shape: [n_embd, n_tokens]
+    // output  shape: [n_embd, n_tokens]
+    ggml_tensor * gaussian_topk(ggml_tensor * x) {
+        ggml_tensor * mean = ggml_mean(ctx0, x);
+        ggml_tensor * std  = ggml_sqrt(ctx0, ggml_scale(ctx0,
+            ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
+            1.0f / (float)(x->ne[0] - 1)
+        ));
+        ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
+        return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
+    }
+
+    //
+    // altup functions
+    //
+
+    // equivalent to compute_router_modalities() in python code
+    // input x shape: [n_embd,  n_tokens]
+    // output  shape: [n_altup, n_tokens]
+    ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il) {
+        ggml_tensor * router_inputs = build_norm(x,
+            model.layers[il].altup_router_norm, NULL,
+            LLM_NORM_RMS, il);
+
+        // router_input_scale
+        router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd);
+
+        ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
+        return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
+    }
+
+    // input cur shape: [n_embd, n_tokens, n_altup]
+    // output    shape: [n_embd, n_tokens, n_altup]
+    ggml_tensor * altup_predict(ggml_tensor * cur, int il) {
+        ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
+        ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
+        cb(modalities, "modalities", il);
+
+        ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
+        cb(all_coefs, "all_coefs", il);
+        // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor)
+        all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
+
+        // permute to [n_altup, n_embd, n_tokens]
+        ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
+        ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens]
+
+        // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
+        predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
+        predictions = ggml_add(ctx0, predictions, cur);
+        cb(predictions, "predictions", il);
+
+        return predictions;
+    }
+
+    // input predictions       shape: [n_embd, n_tokens, n_altup]
+    // input activated         shape: [n_embd, n_tokens]
+    // output                  shape: [n_embd, n_tokens, n_altup]
+    ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
+        ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
+        cb(modalities, "modalities", il);
+
+        ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
+        ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
+        cb(innovation, "innovation", il);
+
+        ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
+        all_coefs = ggml_add(ctx0, all_coefs, one);
+        cb(all_coefs, "all_coefs", il);
+        all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup]
+        all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
+
+        innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
+        ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
+        corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
+        cb(corrected, "corrected", il);
+
+        return corrected;
+    }
+};
+
 // TODO: move up next to build_starcoder
 struct llm_build_starcoder2 : public llm_graph_context {
     llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
@@ -13974,6 +14486,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
             } break;
+        case LLM_ARCH_GEMMA3N:
+            {
+                llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
@@ -14295,6 +14811,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_GEMMA3:
+        case LLM_ARCH_GEMMA3N:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
index 06e6c687943cc23e615bd1f49f773347d2b4247b..40063b790d434d19425a207ede1b9bde129bd214 100644 (file)
@@ -95,6 +95,8 @@ enum llm_type {
     LLM_TYPE_17B_128E, // llama4 Maverick
     LLM_TYPE_30B_A3B,
     LLM_TYPE_235B_A22B,
+    LLM_TYPE_E2B,
+    LLM_TYPE_E4B,
 };
 
 std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
@@ -316,6 +318,19 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_scale   = nullptr;
     struct ggml_tensor * ffn_down_scale = nullptr;
 
+    // altup & laurel
+    struct ggml_tensor * per_layer_inp_gate   = nullptr;
+    struct ggml_tensor * per_layer_proj       = nullptr;
+    struct ggml_tensor * per_layer_post_norm  = nullptr;
+    struct ggml_tensor * altup_correct_coef   = nullptr;
+    struct ggml_tensor * altup_correct_scale  = nullptr;
+    struct ggml_tensor * altup_predict_coef   = nullptr;
+    struct ggml_tensor * altup_router         = nullptr;
+    struct ggml_tensor * altup_router_norm    = nullptr;
+    struct ggml_tensor * laurel_l             = nullptr;
+    struct ggml_tensor * laurel_r             = nullptr;
+    struct ggml_tensor * laurel_post_norm     = nullptr;
+
     struct llama_layer_posnet posnet;
 
     struct llama_layer_convnext convnext;
@@ -354,6 +369,13 @@ struct llama_model {
     struct ggml_tensor * conv1d   = nullptr;
     struct ggml_tensor * conv1d_b = nullptr;
 
+    // gemma3n altup
+    struct ggml_tensor * tok_embd_per_layer   = nullptr;
+    struct ggml_tensor * altup_proj           = nullptr;
+    struct ggml_tensor * altup_unembd_proj    = nullptr;
+    struct ggml_tensor * per_layer_model_proj = nullptr;
+    struct ggml_tensor * per_layer_proj_norm  = nullptr;
+
     std::vector<llama_layer> layers;
 
     llama_model_params params;
index 43229e1938597a359f9f27ecccdc8a890982338d..f4b5713d7dd9aefe1f26759bd15d33a6e9487fe8 100644 (file)
@@ -223,7 +223,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
                 new_type = GGML_TYPE_Q6_K;
             }
         }
-    } else if (name == "token_embd.weight") {
+    } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
         if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
             new_type = qs.params->token_embedding_type;
         } else {
@@ -830,6 +830,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
 
+        // these are very small (e.g. 4x4)
+        quantize &= name.find("altup")  == std::string::npos;
+        quantize &= name.find("laurel") == std::string::npos;
+
+        // these are not too big so keep them as it is
+        quantize &= name.find("per_layer_model_proj") == std::string::npos;
+
         // do not quantize positional embeddings and token types (BERT)
         quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD,    "weight");
         quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");