gguf.MODEL_TENSOR.POSNET_NORM2,
gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
+ gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
+ gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
)
)
or not new_name.endswith(".weight")
self.match_model_tensor_name(new_name, key, bid)
for key in (
gguf.MODEL_TENSOR.TOKEN_EMBD,
+ gguf.MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
gguf.MODEL_TENSOR.OUTPUT,
+ gguf.MODEL_TENSOR.ALTUP_ROUTER,
+ gguf.MODEL_TENSOR.LAUREL_L,
+ gguf.MODEL_TENSOR.LAUREL_R,
)
):
if self.ftype in (
tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path))
- vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+ vocab_size = self.find_hparam([
+ "vocab_size_per_layer_input", # gemma3n
+ "vocab_size",
+ ], optional=True) or tokenizer.vocab_size()
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
for token_id in range(tokenizer.vocab_size()):
piece = tokenizer.IdToPiece(token_id)
text = piece.encode("utf-8")
score = tokenizer.GetScore(token_id)
elif tokenizer.IsByte(token_id):
toktype = SentencePieceTokenTypes.BYTE
+ if token_id >= vocab_size:
+ logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}')
+ break
+
tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype
@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
class Gemma3Model(TextModel):
model_arch = gguf.MODEL_ARCH.GEMMA3
+ norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value
def set_vocab(self):
self._set_vocab_sentencepiece()
self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
- # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+ # attn_logit_softcapping is removed in Gemma3
assert hparams.get("attn_logit_softcapping") is None
- assert hparams.get("final_logit_softcapping") is None
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
if hparams.get("rope_scaling") is not None:
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
- if name.startswith("language_model."):
+ if "language_model." in name:
name = name.replace("language_model.", "")
elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
# ref code in Gemma3RMSNorm
# output = output * (1.0 + self.weight.float())
+ # note: this is not the case for gemma3n
if name.endswith("norm.weight"):
- data_torch = data_torch + 1
+ data_torch = data_torch + self.norm_shift
return [(self.map_tensor_name(name), data_torch)]
return [] # skip other tensors
+@ModelBase.register("Gemma3nForConditionalGeneration")
+class Gemma3NModel(Gemma3Model):
+ model_arch = gguf.MODEL_ARCH.GEMMA3N
+ norm_shift = 0.0 # same value as the Gemma3p5RMSNorm scale_shift in the python code
+
+ _altup_proj: list[Tensor] = []
+ _altup_unembd: list[Tensor] = []
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs"
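+ # altup_projections and altup_unembed_projections each consist of (altup_num_inputs - 1) = 3 matrices; collect them here and stack into a single tensor once all are seen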
+ self._altup_proj = [
+ torch.Tensor(), # to be replaced
+ torch.Tensor(), # to be replaced
+ torch.Tensor(), # to be replaced
+ ]
+ self._altup_unembd = [
+ torch.Tensor(), # to be replaced
+ torch.Tensor(), # to be replaced
+ torch.Tensor(), # to be replaced
+ ]
+
+ def set_vocab(self):
+ with open(self.dir_model / "chat_template.jinja") as f:
+ # quick hack to make sure chat template is added
+ self.gguf_writer.add_chat_template(f.read())
+ super().set_vocab()
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
+ self.gguf_writer.add_altup_num_inputs(self.hparams["altup_num_inputs"])
+ self.gguf_writer.add_embedding_length_per_layer_input(self.hparams["hidden_size_per_layer_input"])
+ self.gguf_writer.add_shared_kv_layers(self.hparams["num_kv_shared_layers"])
+
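+ # convert each target sparsity fraction (e.g. 0.95) into a std multiplier via the inverse CDF of a standard normal; at inference the activation cutoff is mean + std_multiplier * std (see gaussian_topk on the C++ side)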
+ activation_sparsity_scale = []
+ for s in self.hparams["activation_sparsity_pattern"]:
+ normal_dist = torch.distributions.normal.Normal(0, 1)
+ std_multiplier = normal_dist.icdf(torch.tensor(s, dtype=torch.float32))
+ activation_sparsity_scale.append(std_multiplier.item())
+ self.gguf_writer.add_activation_sparsity_scale(activation_sparsity_scale)
+
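+ # one flag per layer: True for sliding-window attention, False for full attention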
+ sliding_window_pattern = []
+ for t in self.hparams["layer_types"]:
+ sliding_window_pattern.append(t == "sliding_attention")
+ self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
+
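+ # stack the collected projection matrices along a new leading dim; return None while any of them is still empty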
+ def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None:
+ has_all = all(m.numel() > 0 for m in matrices)
+ if not has_all:
+ return None
+ else:
+ return torch.stack(matrices, dim=0)
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ if name.endswith("_scale"):
+ name = name + ".weight"
+
+ # TODO: implement self.prediction_coefs.weight.clamp_(...)
+
+ if "language_model." not in name:
+ return [] # skip non-language model tensors
+
+ if "altup_unembed_projections" in name:
+ data_torch = data_torch.to(device="cpu")
+ if ".0." in name:
+ self._altup_unembd[0] = data_torch
+ elif ".1." in name:
+ self._altup_unembd[1] = data_torch
+ elif ".2." in name:
+ self._altup_unembd[2] = data_torch
+ else:
+ raise ValueError(f"Unknown name: {name}")
+ out = self._stack_matrices(self._altup_unembd)
+ if out is not None:
+ return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)]
+ else:
+ return []
+
+ if "altup_projections" in name:
+ data_torch = data_torch.to(device="cpu")
+ if ".0." in name:
+ self._altup_proj[0] = data_torch
+ elif ".1." in name:
+ self._altup_proj[1] = data_torch
+ elif ".2." in name:
+ self._altup_proj[2] = data_torch
+ else:
+ raise ValueError(f"Unknown name: {name}")
+ out = self._stack_matrices(self._altup_proj)
+ if out is not None:
+ return [(self.map_tensor_name("model.altup_projections.weight"), out)]
+ else:
+ return []
+
+ return super().modify_tensors(data_torch, name, bid)
+
+
@ModelBase.register("Starcoder2ForCausalLM")
class StarCoder2Model(TextModel):
model_arch = gguf.MODEL_ARCH.STARCODER2
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_235B_A22B: return "235B.A22B";
+ case LLM_TYPE_E2B: return "E2B";
+ case LLM_TYPE_E4B: return "E4B";
default: return "?B";
}
}
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
} break;
+ case LLM_ARCH_GEMMA3N:
+ {
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
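+ // pattern of 5: every 5th layer uses full attention, the rest use sliding-window attention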
+ hparams.set_swa_pattern(5);
+
+ hparams.rope_freq_base_train_swa = 10000.0f;
+ hparams.rope_freq_scale_train_swa = 1.0f;
+ hparams.f_attention_scale = 1.0f;
+
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 30: type = LLM_TYPE_E2B; break;
+ case 35: type = LLM_TYPE_E4B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
+ case LLM_ARCH_GEMMA3N:
+ {
+ const int64_t n_altup = hparams.n_altup;
+ const int64_t laurel_rank = hparams.laurel_rank;
+ const int64_t n_embd_altup = hparams.n_embd_altup;
+
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
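+ // per-layer token embeddings: for each token, n_layer embeddings of size n_embd_altup flattened into one column of n_embd_altup * n_layer values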
+ tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
+
+ altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
+ altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
+ per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
+ per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
+
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+
+ // altup & laurel
+ layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0);
+ layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0);
+ layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
+ layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0);
+ layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0);
+ layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0);
+ layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0);
+ layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0);
+ layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0);
+ layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0);
+ layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0);
+ }
+ } break;
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
}
};
+struct llm_build_gemma3n_iswa : public llm_graph_context {
+ const llama_model & model;
+ ggml_cgraph * gf;
+
+ const int64_t n_embd_head;
+ const int64_t n_embd_altup;
+ const int64_t n_altup;
+ const int i_altup_act;
+ const int n_layer_kv = 20; // number of layers with their own KV cache [KV_REUSE]
+ const int n_layer_sparsity = 10; // number of layers using activation sparsity
+ const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
+
+ ggml_tensor * one; // containing single element 1.0f
+
+ llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
+ : llm_graph_context(params),
+ model(model),
+ gf(gf),
+ n_embd_head(model.hparams.n_embd_head_k),
+ n_embd_altup(model.hparams.n_embd_altup),
+ n_altup(model.hparams.n_altup),
+ i_altup_act(model.hparams.i_altup_act) {
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ // TODO: remove this when ggml_scale_add is implemented
+ one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ {
+ auto inp = std::make_unique<llm_graph_input_one>();
+ inp->one = one;
+ res->add_input(std::move(inp));
+ }
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
+ if (ubatch.token) {
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+ cb(inpL, "inp_scaled", -1);
+ }
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ // TODO: is causal == true correct? might need some changes
+ auto * inp_attn = build_attn_inp_kv_unified_iswa();
+
+ // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
+ ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
+
+ // inpL currently has only 1 altup; project it to produce the remaining altups
+ // these "added" altups are then concatenated along the last dim of inpL
+ {
+ ggml_tensor * target_magnitude = calc_magnitude(inpL);
+ ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
+ ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
+ ggml_tensor * new_magnitude = calc_magnitude(altup_added);
+ altup_added = ggml_div(ctx0,
+ ggml_mul(ctx0, altup_added, target_magnitude),
+ new_magnitude);
+ inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
+ cb(inpL, "inp_stacked", -1);
+ }
+
+ // inpL now has shape: [n_embd, n_tokens, n_altup]
+ // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
+
+ for (int il = 0; il < n_layer; ++il) {
+ // this block is made to closely resemble Gemma3p5DecoderLayer in the python code
+ const bool has_kv = (il < n_layer_kv);
+
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+ ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
+ ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
+
+ // predicted value will go through self-attention and laurel
+ ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
+ cur = active_prediction;
+ cb(cur, "active_prediction", il);
+
+ // norm
+ cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // laurel
+ ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
+
+ // self-attention
+ if (has_kv) {
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+ Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
+
+ cb(Qcur, "Qcur_normed", il);
+ cb(Kcur, "Kcur_normed", il);
+ cb(Vcur, "Vcur_normed", il);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ cb(Qcur, "Qcur_pos", il);
+ cb(Kcur, "Kcur_pos", il);
+
+ cur = build_attn(inp_attn, gf,
+ model.layers[il].wo, NULL,
+ Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
+ } else {
+ // KV-reuse layer: no KV cache of its own, reuse the KV from the last layer that has one
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+ cb(Qcur, "Qcur_normed", il);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ cb(Qcur, "Qcur_pos", il);
+
+ cur = build_attn(inp_attn, gf,
+ model.layers[il].wo, NULL,
+ Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
+ }
+
+ cur = build_norm(cur,
+ model.layers[il].attn_post_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_post_norm", il);
+
+ cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
+ cb(cur, "attn_gated", il);
+
+ ggml_tensor * attn_laurel = ggml_scale(ctx0,
+ ggml_add(ctx0, cur, laurel_out),
+ 1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
+ cb(attn_laurel, "attn_laurel", il);
+
+ cur = build_norm(attn_laurel,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ // feed-forward network
+ {
+ ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur);
+ ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
+
+ if (il < n_layer_sparsity) {
+ // apply activation sparsity
+ gate_proj = gaussian_topk(gate_proj);
+ }
+ gate_proj = ggml_gelu(ctx0, gate_proj);
+
+ cur = ggml_mul(ctx0, up_proj, gate_proj);
+ cur = build_lora_mm(model.layers[il].ffn_down, cur);
+ cb(cur, "ffn_out", il);
+ }
+
+ cur = build_norm(cur,
+ model.layers[il].ffn_post_norm, NULL,
+ LLM_NORM_RMS, -1);
+ cb(cur, "ffn_post_norm", il);
+
+ ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
+ cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
+
+ ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]
+
+ ggml_tensor * first_prediction; // [n_embd, n_tokens]
+ {
+ first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
+ first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
+ first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
+ first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
+ cb(first_prediction, "first_prediction_gated", il);
+ ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
+ first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
+ cb(first_prediction, "first_prediction_scaled", il);
+
+ first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
+ first_prediction = build_norm(first_prediction,
+ model.layers[il].per_layer_post_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(first_prediction, "first_prediction_out", il);
+ }
+
+ // equivalent to python code: corrected_predictions[1:] += first_prediction
+ {
+ ggml_tensor * slice_first = view_2d_slice(corrected, 0);
+ ggml_tensor * slice_rest = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
+ ggml_row_size(corrected->type, n_embd),
+ ggml_row_size(corrected->type, n_embd*n_tokens),
+ n_embd*n_tokens*ggml_element_size(corrected));
+ ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
+ corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
+ }
+
+ cur = corrected; // [n_embd, n_tokens, n_altup]
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL; // [n_embd, n_tokens, n_altup]
+
+ // cur now has multiple altups; merge them back into a single altup
+ {
+ ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
+ // do a view to skip the first slice (active altup)
+ ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1,
+ ggml_row_size(cur->type, n_embd),
+ ggml_row_size(cur->type, n_embd*n_tokens),
+ n_embd*n_tokens*ggml_element_size(cur));
+ ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
+ ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
+ altup_unembd = ggml_div(ctx0,
+ ggml_mul(ctx0, altup_unembd, target_magnitude),
+ new_magnitude);
+ cb(altup_unembd, "altup_unembd", -1);
+
+ // equivalent to torch.mean(hidden_states, dim=0)
+ cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
+ for (int i = 0; i < n_altup - 1; ++i) {
+ cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
+ }
+ cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
+ cb(cur, "unembd_merged", -1);
+ }
+
+ // cur now has shape: [n_embd, n_tokens]
+
+ // TODO: move this to right after the last KV layer
+ {
+ // skip computing output for unused tokens
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ }
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ cur = build_lora_mm(model.output, cur);
+
+ {
+ // final logit soft-capping
+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+ cur = ggml_tanh(ctx0, cur);
+ cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+ }
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+
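+ // per-token L2 norm along the first (feature) dim: sqrt(sum(x^2)); the result has ne[0] == 1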
+ ggml_tensor * calc_magnitude(ggml_tensor * x) {
+ return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
+ }
+
+ // get a 2D slice view from a 3D tensor; idx indexes the 3rd dim
+ ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) {
+ GGML_ASSERT(idx < (int)x->ne[2]);
+ return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1],
+ ggml_row_size(x->type, x->ne[0]),
+ idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
+ }
+
+ // equivalent to get_per_layer_inputs() in python code
+ // output shape: [n_embd_altup, n_layer, n_tokens]
+ ggml_tensor * get_per_layer_inputs() {
+ auto inp = std::make_unique<llm_graph_input_embd>();
+ ggml_tensor * inp_per_layer;
+ if (ubatch.token) {
+ inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
+ ggml_set_input(inp->tokens);
+ res->t_tokens = inp->tokens;
+ inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
+ inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
+ inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float)n_embd_altup));
+ cb(inp_per_layer, "inp_per_layer_selected", -1);
+ } else {
+ GGML_ABORT("TODO: support embd input");
+ }
+ res->add_input(std::move(inp));
+ return inp_per_layer;
+ }
+
+ // equivalent to project_per_layer_inputs() in python code
+ // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
+ // output shape: [n_embd_altup, n_tokens, n_layer]
+ ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
+ const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd);
+ const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
+
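+ // per_layer_proj = RMSNorm(per_layer_model_proj(inputs_embeds) / sqrt(n_embd)); it is then added to inp_per_layer and the sum is scaled by 1/sqrt(2)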
+ ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
+ per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
+ per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
+ per_layer_proj = build_norm(per_layer_proj,
+ model.per_layer_proj_norm, NULL,
+ LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens]
+ cb(per_layer_proj, "per_layer_proj", -1);
+
+ inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
+ inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
+ cb(inp_per_layer, "inp_per_layer", -1);
+
+ // permute to shape: [n_embd_altup, n_tokens, n_layer]
+ inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
+ return inp_per_layer;
+ }
+
+ // input cur shape: [n_embd, n_tokens]
+ // output shape: [n_embd, n_tokens]
+ ggml_tensor * laurel(ggml_tensor * cur, int il) {
+ ggml_tensor * tmp = cur;
+ tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
+ tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
+ tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
+ tmp = ggml_add(ctx0, tmp, cur);
+ cb(tmp, "laurel_out", il);
+ return tmp;
+ }
+
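+ // keep only activations above the per-token cutoff mean + f_sparsity_std_mul * std (computed along the feature dim); values below the cutoff are zeroed by the relu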
+ // input x shape: [n_embd, n_tokens]
+ // output shape: [n_embd, n_tokens]
+ ggml_tensor * gaussian_topk(ggml_tensor * x) {
+ ggml_tensor * mean = ggml_mean(ctx0, x);
+ ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0,
+ ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
+ 1.0f / (float)(x->ne[0] - 1)
+ ));
+ ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
+ return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
+ }
+
+ //
+ // altup functions
+ //
+
+ // equivalent to compute_router_modalities() in python code
+ // input x shape: [n_embd, n_tokens]
+ // output shape: [n_altup, n_tokens]
+ ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il) {
+ ggml_tensor * router_inputs = build_norm(x,
+ model.layers[il].altup_router_norm, NULL,
+ LLM_NORM_RMS, il);
+
+ // router_input_scale
+ router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd);
+
+ ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
+ return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
+ }
+
+ // input cur shape: [n_embd, n_tokens, n_altup]
+ // output shape: [n_embd, n_tokens, n_altup]
+ ggml_tensor * altup_predict(ggml_tensor * cur, int il) {
+ ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
+ ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
+ cb(modalities, "modalities", il);
+
+ ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
+ cb(all_coefs, "all_coefs", il);
+ // the first dim now has n_altup^2 elements; reshape it to 2D (so we end up with a 3D tensor)
+ all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
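+ // all_coefs is now an n_altup x n_altup mixing matrix per token: each predicted altup is a linear combination of the existing altup slices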
+
+ // permute to [n_altup, n_embd, n_tokens]
+ ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
+ ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_embd, n_altup, n_tokens]
+
+ // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
+ predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
+ predictions = ggml_add(ctx0, predictions, cur);
+ cb(predictions, "predictions", il);
+
+ return predictions;
+ }
+
+ // input predictions shape: [n_embd, n_tokens, n_altup]
+ // input activated shape: [n_embd, n_tokens]
+ // output shape: [n_embd, n_tokens, n_altup]
+ ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
+ ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
+ cb(modalities, "modalities", il);
+
+ ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
+ ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
+ cb(innovation, "innovation", il);
+
+ ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
+ all_coefs = ggml_add(ctx0, all_coefs, one);
+ cb(all_coefs, "all_coefs", il);
+ all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup]
+ all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
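+ // broadcast the per-altup coefficient over n_embd: corrected[j] = predictions[j] + coef[j] * innovation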
+
+ innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
+ ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
+ corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
+ cb(corrected, "corrected", il);
+
+ return corrected;
+ }
+};
+
// TODO: move up next to build_starcoder
struct llm_build_starcoder2 : public llm_graph_context {
llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
{
llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
} break;
+ case LLM_ARCH_GEMMA3N:
+ {
+ llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
+ } break;
case LLM_ARCH_STARCODER2:
{
llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
+ case LLM_ARCH_GEMMA3N:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX: