norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value
def set_vocab(self):
- self._set_vocab_sentencepiece()
-
- self.gguf_writer.add_add_space_prefix(False)
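+ # not every Gemma3-family checkpoint ships a SentencePiece tokenizer.model;
+ # fall back to the BPE (GPT-2 style) vocab otherwise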
+ if (self.dir_model / "tokenizer.model").is_file():
+ self._set_vocab_sentencepiece()
+ self.gguf_writer.add_add_space_prefix(False)
+ else:
+ self._set_vocab_gpt2()
def set_gguf_parameters(self):
hparams = self.hparams
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
# attn_logit_softcapping is removed in Gemma3
assert hparams.get("attn_logit_softcapping") is None
- self.gguf_writer.add_sliding_window(hparams["sliding_window"])
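+ # final_logit_softcapping squashes logits to (-c, c) via c * tanh(logits / c), as in Gemma2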
+ if (final_logit_softcap := hparams.get("final_logit_softcapping")):
+ self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
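+ # a sliding_window_pattern of 1 means every layer uses global attention, so there is no sliding window to record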
+ if hparams.get("sliding_window_pattern") != 1:
+ self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
if hparams.get("rope_scaling") is not None:
- assert hparams["rope_scaling"]["rope_type"] == "linear"
- # important: this rope_scaling is only applied for global layers, and not used by 1B model
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
- self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+ rope_scaling = hparams["rope_scaling"]
+ if rope_scaling["rope_type"] == "linear":
+ # important: this rope_scaling is only applied for global layers, and not used by 1B model
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+ self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+ elif rope_scaling["rope_type"] == "yarn":
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+ self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+ self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+ self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"])
+ self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"])
+ self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"])
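+ # illustrative (hypothetical) config.json entry handled by the yarn branch above:
+ #   "rope_scaling": {"rope_type": "yarn", "factor": 8.0,
+ #                    "original_max_position_embeddings": 8192,
+ #                    "extrapolation_factor": 1.0,
+ #                    "beta_fast": 32.0, "beta_slow": 1.0}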
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
# remove OOV (out-of-vocabulary) rows in token_embd
if "embed_tokens.weight" in name:
- vocab = self._create_vocab_sentencepiece()
- tokens = vocab[0]
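+ # mirror set_vocab: take tokens from SentencePiece when tokenizer.model exists, otherwise from the base (BPE) vocab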
+ if (self.dir_model / "tokenizer.model").is_file():
+ tokens = self._create_vocab_sentencepiece()[0]
+ else:
+ tokens = self.get_vocab_base()[0]
data_torch = data_torch[:len(tokens)]
# ref code in Gemma3RMSNorm
models/gemma-embedding.cpp
models/gemma.cpp
models/gemma2-iswa.cpp
- models/gemma3-iswa.cpp
+ models/gemma3.cpp
models/gemma3n-iswa.cpp
models/glm4-moe.cpp
models/glm4.cpp
} break;
case LLM_ARCH_GEMMA3:
{
- hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
- hparams.set_swa_pattern(6);
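+ // the sliding-window key is now optional: checkpoints without it (or with n_swa == 0) use global attention on every layer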
+ const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+ if (found_swa && hparams.n_swa > 0) {
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+ hparams.set_swa_pattern(6);
- hparams.rope_freq_base_train_swa = 10000.0f;
- hparams.rope_freq_scale_train_swa = 1.0f;
+ hparams.rope_freq_base_train_swa = 10000.0f;
+ hparams.rope_freq_scale_train_swa = 1.0f;
+ } else {
+ hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+ }
- ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
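+ // final logit softcapping is optional and disabled by default (0.0f); read the key only if present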
+ hparams.f_final_logit_softcapping = 0.0f;
+ ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 18: type = LLM_TYPE_270M; break;
case 26: type = LLM_TYPE_1B; break;
+ case 32: type = LLM_TYPE_8B; break; // Rnj-1
case 34: type = LLM_TYPE_4B; break;
case 48: type = LLM_TYPE_12B; break;
case 62: type = LLM_TYPE_27B; break;
} break;
case LLM_ARCH_GEMMA3:
{
- llm = std::make_unique<llm_build_gemma3_iswa>(*this, params);
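+ // both variants share one graph; the template parameter only selects which KV-cache input is built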
+ if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+ llm = std::make_unique<llm_build_gemma3<true>>(*this, params);
+ } else {
+ llm = std::make_unique<llm_build_gemma3<false>>(*this, params);
+ }
} break;
case LLM_ARCH_GEMMA3N:
{
+++ /dev/null
-#include "models.h"
-
-llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
- const int64_t n_embd_head = hparams.n_embd_head_k;
-
- ggml_tensor * cur;
- ggml_tensor * inpL;
-
- inpL = build_inp_embd(model.tok_embd);
-
- // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
- if (ubatch.token) {
- inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
- cb(inpL, "inp_scaled", -1);
- }
- // inp_pos - contains the positions
- ggml_tensor * inp_pos = build_inp_pos();
-
- // TODO: is causal == true correct? might need some changes
- auto * inp_attn = build_attn_inp_kv_iswa();
-
- ggml_tensor * inp_out_ids = build_inp_out_ids();
-
- for (int il = 0; il < n_layer; ++il) {
- const float freq_base_l = model.get_rope_freq_base (cparams, il);
- const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
-
- // norm
- cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
- cb(cur, "attn_norm", il);
-
- // self-attention
- {
- // compute Q and K and RoPE them
- ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
- cb(Qcur, "Qcur", il);
-
- ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
- cb(Kcur, "Kcur", il);
-
- ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
- cb(Vcur, "Vcur", il);
-
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
- Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
- cb(Qcur, "Qcur_normed", il);
-
- Qcur = ggml_rope_ext(
- ctx0, Qcur, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
- ext_factor, attn_factor, beta_fast, beta_slow);
-
- Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
- cb(Kcur, "Kcur_normed", il);
-
- Kcur = ggml_rope_ext(
- ctx0, Kcur, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
- ext_factor, attn_factor, beta_fast, beta_slow);
-
- cb(Qcur, "Qcur", il);
- cb(Kcur, "Kcur", il);
- cb(Vcur, "Vcur", il);
-
- // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
- Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
-
- cur = build_attn(inp_attn,
- model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
- }
- if (il == n_layer - 1 && inp_out_ids) {
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
- inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
- }
- cur = build_norm(cur,
- model.layers[il].attn_post_norm, NULL,
- LLM_NORM_RMS, il);
- cb(cur, "attn_post_norm", il);
-
- ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
- cb(sa_out, "sa_out", il);
-
- cur = build_norm(sa_out,
- model.layers[il].ffn_norm, NULL,
- LLM_NORM_RMS, il);
- cb(cur, "ffn_norm", il);
-
- // feed-forward network
- {
- cur = build_ffn(cur,
- model.layers[il].ffn_up, NULL, NULL,
- model.layers[il].ffn_gate, NULL, NULL,
- model.layers[il].ffn_down, NULL, NULL,
- NULL,
- LLM_FFN_GELU, LLM_FFN_PAR, il);
- cb(cur, "ffn_out", il);
- }
- cur = build_norm(cur,
- model.layers[il].ffn_post_norm, NULL,
- LLM_NORM_RMS, -1);
- cb(cur, "ffn_post_norm", -1);
-
- cur = ggml_add(ctx0, cur, sa_out);
-
- cur = build_cvec(cur, il);
- cb(cur, "l_out", il);
-
- // input for next layer
- inpL = cur;
- }
- cur = inpL;
-
- cur = build_norm(cur,
- model.output_norm, NULL,
- LLM_NORM_RMS, -1);
-
- cb(cur, "result_norm", -1);
- res->t_embd = cur;
-
- // lm_head
- cur = build_lora_mm(model.output, cur);
-
- cb(cur, "result_output", -1);
- res->t_logits = cur;
-
- ggml_build_forward_expand(gf, cur);
-}
--- /dev/null
+#include "models.h"
+
+template <bool iswa>
+llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_k;
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
+ if (ubatch.token) {
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+ cb(inpL, "inp_scaled", -1);
+ }
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ // TODO: is causal == true correct? might need some changes
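+ // select the attention input type at compile time: the iSWA variant needs the
+ // split sliding-window/global KV cache input, the plain variant a single unified one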
+ using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+ inp_attn_type * inp_attn = nullptr;
+
+ if constexpr (iswa) {
+ inp_attn = build_attn_inp_kv_iswa();
+ } else {
+ inp_attn = build_attn_inp_kv();
+ }
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ float freq_base_l = 0.0f;
+ float freq_scale_l = 0.0f;
+
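+ // iSWA checkpoints use per-layer RoPE parameters (SWA layers keep the 10k base,
+ // global layers the long-context base); without SWA every layer shares the single trained base/scale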
+ if constexpr (iswa) {
+ freq_base_l = model.get_rope_freq_base (cparams, il);
+ freq_scale_l = model.get_rope_freq_scale(cparams, il);
+ } else {
+ freq_base_l = freq_base;
+ freq_scale_l = freq_scale;
+ }
+
+ // norm
+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+ cb(Qcur, "Qcur_normed", il);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+ cb(Kcur, "Kcur_normed", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
+ cur = build_attn(inp_attn,
+ model.layers[il].wo, NULL,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
+ }
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+ cur = build_norm(cur,
+ model.layers[il].attn_post_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_post_norm", il);
+
+ ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+ cb(sa_out, "sa_out", il);
+
+ cur = build_norm(sa_out,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ // feed-forward network
+ {
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_GELU, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
+ }
+ cur = build_norm(cur,
+ model.layers[il].ffn_post_norm, NULL,
+ LLM_NORM_RMS, -1);
+ cb(cur, "ffn_post_norm", il);
+
+ cur = ggml_add(ctx0, cur, sa_out);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+
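+ // Gemma2-style final logit softcapping: squash logits into (-cap, cap) via cap * tanh(logits / cap)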
+ if (hparams.f_final_logit_softcapping) {
+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+ cur = ggml_tanh(ctx0, cur);
+ cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+ }
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+}
+
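+ // explicit instantiations: the template is defined in this translation unit, so both variants must be emitted here for the linker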
+template struct llm_build_gemma3<false>;
+template struct llm_build_gemma3<true>;
llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params);
};
-struct llm_build_gemma3_iswa : public llm_graph_context {
- llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params);
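+ // iswa == true builds the interleaved sliding-window attention variant; false, the global-attention variant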
+template <bool iswa>
+struct llm_build_gemma3 : public llm_graph_context {
+ llm_build_gemma3(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_gemma3n_iswa : public llm_graph_context {