for token_id, token_data in added_tokens_decoder.items():
token_id = int(token_id)
token: str = token_data["content"]
+ if token_id >= vocab_size:
+                logger.warning(f'ignoring token {token_id}: id is out of range, max={vocab_size - 1}')
+ continue
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
if tokens[token_id] != token.encode("utf-8"):
logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token!r}')
return [(self.map_tensor_name(name), data_torch)]
+@Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
+class Gemma3Model(Model):
+ model_arch = gguf.MODEL_ARCH.GEMMA3
+ has_vision: bool = False
+
+ # we need to merge the text_config into the root level of hparams
+ def __init__(self, *args, **kwargs):
+ hparams = Model.load_hparams(kwargs["dir_model"])
+ if "text_config" in hparams:
+ hparams = {**hparams, **hparams["text_config"]}
+ kwargs["hparams"] = hparams
+ super().__init__(*args, **kwargs)
+ if "vision_config" in hparams:
+ logger.info("Has vision encoder, but it will be ignored")
+ self.has_vision = True
+
+ def write(self):
+ super().write()
+ if self.has_vision:
+            logger.info("NOTE: this script only converts the language model to GGUF")
+ logger.info(" for the vision model, please use gemma3_convert_encoder_to_gguf.py")
+
+ def set_vocab(self):
+ self._set_vocab_sentencepiece()
+
+ self.gguf_writer.add_add_space_prefix(False)
+
+ def set_gguf_parameters(self):
+ hparams = self.hparams
+ block_count = hparams["num_hidden_layers"]
+
+ # some default values are not specified in the hparams
+ self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
+ self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+ self.gguf_writer.add_block_count(block_count)
+ self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+ self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
+ self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
+ self.gguf_writer.add_key_length(hparams.get("head_dim", 256))
+ self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
+ self.gguf_writer.add_file_type(self.ftype)
+ self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
+ # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+ assert hparams.get("attn_logit_softcapping") is None
+ assert hparams.get("final_logit_softcapping") is None
+ self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+ self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
+ if hparams.get("rope_scaling") is not None:
+ assert hparams["rope_scaling"]["rope_type"] == "linear"
+            # important: this rope_scaling is only applied to global layers, and is not used by the 1B model
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+ self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ del bid # unused
+
+ if name.startswith("language_model."):
+ name = name.replace("language_model.", "")
+ elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
+                or name.startswith("multimodal_projector.") or name.startswith("vision_model."): # this is for the old HF model; should be removed later
+ # ignore vision tensors
+ return []
+
+ # remove OOV (out-of-vocabulary) rows in token_embd
+ if "embed_tokens.weight" in name:
+ vocab = self._create_vocab_sentencepiece()
+ tokens = vocab[0]
+ data_torch = data_torch[:len(tokens)]
+
+ # ref code in Gemma3RMSNorm
+ # output = output * (1.0 + self.weight.float())
+ if name.endswith("norm.weight"):
+ data_torch = data_torch + 1
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+
@Model.register("Starcoder2ForCausalLM")
class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
+set(TARGET llama-gemma3-cli)
+add_executable(${TARGET} gemma3-cli.cpp)
+set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
set(TARGET llama-llava-clip-quantize-cli)
add_executable(${TARGET} clip-quantize-cli.cpp)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli)
--- /dev/null
+# Gemma 3 vision
+
+> [!IMPORTANT]
+>
+> This is very experimental and intended for demo purposes only.
+
+## How to get mmproj.gguf?
+
+```bash
+cd gemma-3-4b-it
+python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py .
+
+# output file is mmproj.gguf
+```
+
+## How to run it?
+
+What you need:
+- The text model in GGUF format, which can be converted using `convert_hf_to_gguf.py`
+- The mmproj file from the step above
+- An image file
+
+```bash
+# build
+cmake -B build
+cmake --build build --target llama-gemma3-cli
+
+# run it
+./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
+```
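+
+If `--image` and `-p` are omitted, the CLI starts an interactive chat mode and prints the available commands (`/image <path>`, `/clear`, `/quit`) on startup. An illustrative session might look like this (the transcript below is only an example):
+
+```bash
+./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf
+
+# > /image your_image.jpg
+# > what is in this image?
+# (model response is printed here)
+# > /quit
+```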
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline"
+#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
+#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
#define TN_MINICPMV_QUERY "resampler.query"
PROJECTOR_TYPE_RESAMPLER,
PROJECTOR_TYPE_GLM_EDGE,
PROJECTOR_TYPE_MERGER,
+ PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_UNKNOWN,
};
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
+ { PROJECTOR_TYPE_GEMMA3, "gemma3"},
};
return kv.first;
}
}
- return PROJECTOR_TYPE_UNKNOWN;
+ throw std::runtime_error(format("Unknown projector type: %s", name.c_str()));
}
#ifdef CLIP_DEBUG_FUNCTIONS
struct ggml_tensor * mm_model_ln_kv_b;
struct ggml_tensor * mm_model_ln_post_w;
struct ggml_tensor * mm_model_ln_post_b;
+
+ // gemma3
+ struct ggml_tensor * mm_input_proj_w;
+ struct ggml_tensor * mm_soft_emb_norm_w;
};
struct clip_ctx {
struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP;
- int32_t max_feature_layer;
+ int32_t max_feature_layer; // unused in newer models like gemma3
float image_mean[3];
float image_std[3];
bool use_gelu = false;
ggml_backend_sched_ptr sched;
- struct clip_image_size * load_image_size;
+ struct clip_image_size * load_image_size = nullptr;
clip_ctx(clip_context_params & ctx_params) {
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
}
};
-static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
+ const auto & model = ctx->vision_model;
+ const auto & hparams = model.hparams;
+
+ const int image_size = hparams.image_size;
+ int image_size_width = image_size;
+ int image_size_height = image_size;
+
+ const int patch_size = hparams.patch_size;
+ const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
+ const int hidden_size = hparams.hidden_size;
+ const int n_head = hparams.n_head;
+ const int d_head = hidden_size / n_head;
+ const int n_layer = hparams.n_layer;
+ const float eps = hparams.eps;
+
+ GGML_ASSERT(imgs->size == 1); // batch_size == 1
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ ctx->buf_compute_meta.size(),
+ /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+ /*.no_alloc =*/ true,
+ };
+
+ struct ggml_context * ctx0 = ggml_init(params);
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+ // input raw
+ struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
+ ggml_set_name(inp_raw, "inp_raw");
+ ggml_set_input(inp_raw);
+
+ struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+ inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+ inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+ inp = ggml_add(ctx0, inp, model.patch_bias);
+
+ // position embeddings
+ struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
+
+ // loop over layers
+ for (int il = 0; il < n_layer; il++) {
+ struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+
+ // layernorm1
+ {
+ cur = ggml_norm(ctx0, cur, eps);
+ cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
+ }
+
+ // self-attention
+ {
+
+ struct ggml_tensor * Q =
+ ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+
+ Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
+ Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+
+ struct ggml_tensor * K =
+ ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+
+ K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
+ K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+
+ struct ggml_tensor * V =
+ ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+
+ V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
+ V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+ KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head));
+ KQ = ggml_soft_max_inplace(ctx0, KQ);
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+ KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
+ KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+ }
+
+ // attention output
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+
+        // re-add the layer input, i.e. the residual connection
+ cur = ggml_add(ctx0, cur, embeddings);
+
+ embeddings = cur; // embeddings = residual, cur = hidden_states
+
+ // layernorm2
+ {
+ cur = ggml_norm(ctx0, cur, eps);
+ cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
+ }
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+ cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+
+ // siglip uses gelu
+ cur = ggml_gelu(ctx0, cur);
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+ cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+
+ // residual 2
+ cur = ggml_add(ctx0, embeddings, cur);
+
+ embeddings = cur;
+ }
+
+ // post-layernorm
+ if (ctx->has_post_norm) {
+ embeddings = ggml_norm(ctx0, embeddings, eps);
+ ggml_set_name(embeddings, "post_ln");
+
+ embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+ }
+
+ if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+ const int batch_size = 1;
+ const int mm_tokens_per_image = 256; // default value for gemma3
+ const int tokens_per_side = sqrt(mm_tokens_per_image);
+ const int patches_per_image = sqrt(num_patches);
+ const int kernel_size = patches_per_image / tokens_per_side;
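+        // e.g. with the 896x896 input and 14x14 patches used by the released Gemma 3 vision encoder,
+        // patches_per_image = 64 and kernel_size = 4, so 64x64 = 4096 patch embeddings are pooled down to 16x16 = 256 tokens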
+
+ embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
+ embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
+
+ // doing a pool2d to reduce the number of output tokens to 256
+ embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
+ embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
+ embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
+
+ // apply norm before projection
+ embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+ embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w);
+
+ // apply projection
+ embeddings = ggml_mul_mat(ctx0,
+ ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
+ embeddings);
+ }
+
+ // build the graph
+ ggml_build_forward_expand(gf, embeddings);
+
+ ggml_free(ctx0);
+
+ return gf;
+}
+
+static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
if (!ctx->has_vision_encoder) {
LOG_ERR("This gguf file seems to have no vision encoder\n");
return nullptr;
} else {
GGML_ABORT("fatel error");
}
- } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+ }
+ else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
return gf;
}
+static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+ if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+ return clip_image_build_graph_siglip(ctx, imgs);
+ } else {
+ // TODO: we should have one build_* function per model
+ return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
+ }
+}
+
// read and create ggml_context containing the tensors and their data
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
return clip_init(fname, clip_context_params{
GGML_ASSERT(new_clip->has_vision_encoder);
GGML_ASSERT(!new_clip->has_text_encoder);
- idx = get_key_idx(ctx, KEY_USE_GELU);
- new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
+ try {
+ idx = get_key_idx(ctx, KEY_USE_GELU);
+ new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
+ } catch (std::runtime_error & /*e*/) {
+ new_clip->use_gelu = false;
+ }
try {
idx = get_key_idx(ctx, KEY_USE_SILU);
}
try {
- vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+ vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+ } catch(const std::exception& /*e*/) {
+ vision_model.patch_embeddings_0 = nullptr;
+ }
+
+ try {
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
} catch(const std::exception& /*e*/) {
- LOG_ERR("%s: failed to load vision model tensors\n", __func__);
+ vision_model.position_embeddings = nullptr;
}
+
try {
vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
} catch(const std::exception& /*e*/) {
vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
}
+ else if (new_clip->proj_type == PROJECTOR_TYPE_GEMMA3) {
+ vision_model.mm_input_proj_w = get_tensor(new_clip->ctx_data, TN_MM_INP_PROJ);
+ vision_model.mm_soft_emb_norm_w = get_tensor(new_clip->ctx_data, TN_MM_SOFT_EMB_N);
+ }
else {
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
return true;
}
- if (ctx->has_glm_projector) {
+ if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
res_imgs->size = 1;
res_imgs->data = new clip_image_f32[res_imgs->size];
clip_image_u8 resized_image;
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
}
+ else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+ // do nothing
+ }
else {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
return ctx->vision_model.mm_1_b->ne[0];
}
+ if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+ return ctx->vision_model.mm_input_proj_w->ne[0];
+ }
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
--- /dev/null
+#include "arg.h"
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+#include "clip.h"
+#include "stb_image.h"
+#include "llama.h"
+#include "ggml.h"
+#include "console.h"
+
+#include <vector>
+#include <limits.h>
+#include <inttypes.h>
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include <signal.h>
+#include <unistd.h>
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include <windows.h>
+#include <signal.h>
+#endif
+
+static bool g_is_generating = false;
+
+/**
+ * Please note that this is NOT production-ready code.
+ * It is a playground for trying Gemma 3 vision capabilities.
+ * For contributors: please keep this code simple and easy to understand.
+ */
+
+static void show_additional_info(int /*argc*/, char ** argv) {
+ LOG(
+ "Experimental CLI for using Gemma 3 vision model\n\n"
+ "Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
+ " -m and --mmproj are required\n"
+ " --image and -p are optional, if NOT provided, the CLI will run in chat mode\n",
+ argv[0]
+ );
+}
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
+static void sigint_handler(int signo) {
+ if (signo == SIGINT) {
+ if (g_is_generating) {
+ g_is_generating = false;
+ } else {
+ console::cleanup();
+ LOG("\nInterrupted by user\n");
+ _exit(130);
+ }
+ }
+}
+#endif
+
+struct gemma3_context {
+ struct clip_ctx * ctx_clip = NULL;
+ common_init_result llama_init;
+
+ llama_model * model;
+ llama_context * lctx;
+ const llama_vocab * vocab;
+ llama_batch batch;
+
+ int n_threads = 1;
+ llama_pos n_past = 0;
+
+ gemma3_context(common_params & params) : llama_init(common_init_from_params(params)) {
+ model = llama_init.model.get();
+ lctx = llama_init.context.get();
+ vocab = llama_model_get_vocab(model);
+ n_threads = params.cpuparams.n_threads;
+ batch = llama_batch_init(params.n_batch, 0, 1);
+ init_clip_model(params);
+ }
+
+ void init_clip_model(common_params & params) {
+ const char * clip_path = params.mmproj.c_str();
+ ctx_clip = clip_model_load(clip_path, params.verbosity > 1);
+ }
+
+ ~gemma3_context() {
+ clip_free(ctx_clip);
+ }
+};
+
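+// Helper to wrap raw image embeddings in a llama_batch: every entry is assigned to the given
+// sequence id, positions run consecutively from pos_0, and logits are disabled for all entries.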
+struct decode_embd_batch {
+ std::vector<llama_pos> pos;
+ std::vector<int32_t> n_seq_id;
+ std::vector<llama_seq_id> seq_id_0;
+ std::vector<llama_seq_id *> seq_ids;
+ std::vector<int8_t> logits;
+ llama_batch batch;
+ decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+ pos .resize(n_tokens);
+ n_seq_id.resize(n_tokens);
+ seq_ids .resize(n_tokens + 1);
+ logits .resize(n_tokens);
+ seq_id_0.resize(1);
+ seq_id_0[0] = seq_id;
+ seq_ids [n_tokens] = nullptr;
+ batch = {
+ /*n_tokens =*/ n_tokens,
+ /*tokens =*/ nullptr,
+ /*embd =*/ embd,
+ /*pos =*/ pos.data(),
+ /*n_seq_id =*/ n_seq_id.data(),
+ /*seq_id =*/ seq_ids.data(),
+ /*logits =*/ logits.data(),
+ };
+ for (int i = 0; i < n_tokens; i++) {
+ batch.pos [i] = pos_0 + i;
+ batch.n_seq_id[i] = 1;
+ batch.seq_id [i] = seq_id_0.data();
+ batch.logits [i] = false;
+ }
+ }
+};
+
+static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
+ llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
+ common_batch_clear(ctx.batch);
+ for (llama_token & t : tokens) {
+ common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
+ }
+ if (logits_last) {
+ ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
+ }
+ // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
+ if (llama_decode(ctx.lctx, ctx.batch)) {
+ LOG_ERR("Failed to decode text\n");
+ return 1;
+ }
+ return 0;
+}
+
+static int eval_image(gemma3_context & ctx, std::string & fname) {
+ std::vector<float> image_embd_v;
+ int n_embd = llama_model_n_embd(ctx.model);
+ int n_tokens = 256;
+ image_embd_v.resize(n_tokens * n_embd);
+
+ bool ok;
+ struct clip_image_u8 * img_u8 = clip_image_u8_init();
+ ok = clip_image_load_from_file(fname.c_str(), img_u8);
+ if (!ok) {
+ LOG_ERR("Unable to load image %s\n", fname.c_str());
+ clip_image_u8_free(img_u8);
+ return 2; // non-fatal error
+ }
+
+ clip_image_f32_batch batch_f32;
+ ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
+ if (!ok) {
+ LOG_ERR("Unable to preprocess image\n");
+ clip_image_f32_batch_free(&batch_f32);
+ clip_image_u8_free(img_u8);
+ return 1;
+ }
+
+ int64_t t0 = ggml_time_ms();
+ LOG("Encoding image %s\n", fname.c_str());
+ ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data());
+ if (!ok) {
+ LOG_ERR("Unable to encode image\n");
+ clip_image_f32_batch_free(&batch_f32);
+ clip_image_u8_free(img_u8);
+ return 1;
+ }
+ LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
+
+ clip_image_f32_batch_free(&batch_f32);
+ clip_image_u8_free(img_u8);
+
+ // decode image embeddings
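+    // causal attention is temporarily disabled so the image tokens can attend to each other bidirectionally,
+    // and re-enabled right after the image batch is decoded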
+ int64_t t1 = ggml_time_ms();
+ eval_text(ctx, "<start_of_image>");
+ llama_set_causal_attn(ctx.lctx, false);
+ decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
+ if (llama_decode(ctx.lctx, batch_img.batch)) {
+ LOG_ERR("failed to decode image\n");
+ return 1;
+ }
+ ctx.n_past += n_tokens;
+ llama_set_causal_attn(ctx.lctx, true);
+ eval_text(ctx, "<end_of_image>");
+ LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
+ return 0;
+}
+
+static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
+ for (int i = 0; i < n_predict; i++) {
+ if (i > n_predict || !g_is_generating) {
+ printf("\n");
+ break;
+ }
+
+ llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
+ common_sampler_accept(smpl, token_id, true);
+
+ if (llama_vocab_is_eog(ctx.vocab, token_id)) {
+ printf("\n");
+ break; // end of generation
+ }
+
+ printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
+ fflush(stdout);
+
+ // eval the token
+ common_batch_clear(ctx.batch);
+ common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
+ if (llama_decode(ctx.lctx, ctx.batch)) {
+ LOG_ERR("failed to decode token\n");
+ return 1;
+ }
+ }
+ return 0;
+}
+
+int main(int argc, char ** argv) {
+ ggml_time_init();
+
+ common_params params;
+ params.sampling.temp = 0.2; // lower temp by default for better quality
+
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
+ return 1;
+ }
+
+ common_init();
+
+ if (params.mmproj.empty()) {
+ show_additional_info(argc, argv);
+ return 1;
+ }
+
+ gemma3_context ctx(params);
+ printf("%s: %s\n", __func__, params.model.c_str());
+
+ bool is_single_turn = !params.prompt.empty() && !params.image.empty();
+
+ struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
+ int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
+
+ // ctrl+C handling
+ {
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+ struct sigaction sigint_action;
+ sigint_action.sa_handler = sigint_handler;
+ sigemptyset (&sigint_action.sa_mask);
+ sigint_action.sa_flags = 0;
+ sigaction(SIGINT, &sigint_action, NULL);
+#elif defined (_WIN32)
+ auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
+ return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
+ };
+ SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
+#endif
+ }
+
+ if (eval_text(ctx, "<bos>")) {
+ return 1;
+ }
+
+ if (is_single_turn) {
+ g_is_generating = true;
+ if (eval_text(ctx, "<start_of_turn>user\n")) {
+ return 1;
+ }
+ for (auto & fname : params.image) {
+ if (eval_image(ctx, fname)) {
+ return 1;
+ }
+ }
+ if (eval_text(ctx, params.prompt + "<end_of_turn><start_of_turn>model\n", true)) {
+ return 1;
+ }
+ if (generate_response(ctx, smpl, n_predict)) {
+ return 1;
+ }
+
+ } else {
+ LOG("\n Running in chat mode, available commands:");
+ LOG("\n /image <path> load an image");
+ LOG("\n /clear clear the chat history");
+ LOG("\n /quit or /exit exit the program");
+ LOG("\n");
+
+ if (eval_text(ctx, "<start_of_turn>user\n")) {
+ return 1;
+ }
+
+ while (true) {
+ g_is_generating = false;
+ LOG("\n> ");
+ console::set_display(console::user_input);
+ std::string line;
+ console::readline(line, false);
+ console::set_display(console::reset);
+ line = string_strip(line);
+ if (line.empty()) {
+ continue;
+ }
+ if (line == "/quit" || line == "/exit") {
+ break;
+ }
+ if (line == "/clear") {
+ ctx.n_past = 0;
+ llama_kv_cache_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
+ LOG("Chat history cleared\n\n");
+ continue;
+ }
+ g_is_generating = true;
+ if (line.find("/image") == 0) {
+ std::string image = line.substr(7);
+ int res = eval_image(ctx, image);
+ if (res == 2) {
+ continue; // image not found
+ }
+ if (res) {
+ return 1;
+ }
+ continue;
+ }
+ if (eval_text(ctx, line + "<end_of_turn><start_of_turn>model\n", true)) {
+ return 1;
+ }
+ if (generate_response(ctx, smpl, n_predict)) {
+ return 1;
+ }
+ if (eval_text(ctx, "<end_of_turn><start_of_turn>user\n")) {
+ return 1;
+ }
+ }
+ }
+
+ return 0;
+}
--- /dev/null
+import gguf
+import argparse
+import logging
+import sys
+import torch
+import json
+import os
+import numpy as np
+from typing import cast, ContextManager, Any, Iterator
+from pathlib import Path
+from torch import Tensor
+
+logger = logging.getLogger("gemma3-mmproj")
+
+
+# (copied from convert_hf_to_gguf.py)
+# tree of lazy tensors
+class LazyTorchTensor(gguf.LazyBase):
+ _tensor_type = torch.Tensor
+ # to keep the type-checker happy
+ dtype: torch.dtype
+ shape: torch.Size
+
+ # only used when converting a torch.Tensor to a np.ndarray
+ _dtype_map: dict[torch.dtype, type] = {
+ torch.float16: np.float16,
+ torch.float32: np.float32,
+ }
+
+ # used for safetensors slices
+ # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
+ # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
+ _dtype_str_map: dict[str, torch.dtype] = {
+ "F64": torch.float64,
+ "F32": torch.float32,
+ "BF16": torch.bfloat16,
+ "F16": torch.float16,
+ # "U64": torch.uint64,
+ "I64": torch.int64,
+ # "U32": torch.uint32,
+ "I32": torch.int32,
+ # "U16": torch.uint16,
+ "I16": torch.int16,
+ "U8": torch.uint8,
+ "I8": torch.int8,
+ "BOOL": torch.bool,
+ "F8_E4M3": torch.float8_e4m3fn,
+ "F8_E5M2": torch.float8_e5m2,
+ }
+
+ def numpy(self) -> gguf.LazyNumpyTensor:
+ dtype = self._dtype_map[self.dtype]
+ return gguf.LazyNumpyTensor(
+ meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
+ args=(self,),
+ func=(lambda s: s.numpy())
+ )
+
+ @classmethod
+ def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
+ return torch.empty(size=shape, dtype=dtype, device="meta")
+
+ @classmethod
+ def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
+ dtype = cls._dtype_str_map[st_slice.get_dtype()]
+ shape: tuple[int, ...] = tuple(st_slice.get_shape())
+ lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
+ return cast(torch.Tensor, lazy)
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ del types # unused
+
+ if kwargs is None:
+ kwargs = {}
+
+ if func is torch.Tensor.numpy:
+ return args[0].numpy()
+
+ return cls._wrap_fn(func)(*args, **kwargs)
+
+
+class Gemma3VisionTower:
+ hparams: dict
+ gguf_writer: gguf.GGUFWriter
+ fname_out: Path
+ ftype: gguf.LlamaFileType
+
+ @staticmethod
+ def load_hparams(dir_model: Path):
+ with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+ return json.load(f)
+
+ @staticmethod
+ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
+ part_names: list[str] = []
+ for filename in os.listdir(dir_model):
+ if filename.startswith(prefix) and filename.endswith(suffix):
+ part_names.append(filename)
+ part_names.sort()
+ return part_names
+
+ def __init__(self,
+ dir_model: Path,
+ fname_out: Path,
+ ftype: gguf.LlamaFileType,
+ is_big_endian: bool,):
+ hparams = Gemma3VisionTower.load_hparams(dir_model)
+ self.hparams = hparams
+ self.fname_out = fname_out
+ self.ftype = ftype
+ endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
+ self.gguf_writer = gguf.GGUFWriter(path=None, arch="clip", endianess=endianess)
+
+ text_config = hparams["text_config"]
+ vision_config = hparams["vision_config"]
+
+ assert hparams["architectures"][0] == "Gemma3ForConditionalGeneration"
+ assert text_config is not None
+ assert vision_config is not None
+
+ self.gguf_writer.add_string ("clip.projector_type", "gemma3")
+ self.gguf_writer.add_bool ("clip.has_text_encoder", False)
+ self.gguf_writer.add_bool ("clip.has_vision_encoder", True)
+ self.gguf_writer.add_bool ("clip.has_llava_projector", False) # legacy
+ self.gguf_writer.add_uint32 ("clip.vision.image_size", vision_config["image_size"])
+ self.gguf_writer.add_uint32 ("clip.vision.patch_size", vision_config["patch_size"])
+ self.gguf_writer.add_uint32 ("clip.vision.embedding_length", vision_config["hidden_size"])
+ self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length", vision_config["intermediate_size"])
+ self.gguf_writer.add_uint32 ("clip.vision.projection_dim", text_config["hidden_size"])
+ self.gguf_writer.add_uint32 ("clip.vision.block_count", vision_config["num_hidden_layers"])
+ self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", vision_config["num_attention_heads"])
+ self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", vision_config.get("layer_norm_eps", 1e-6))
+        # default values taken from HF transformers code
+ self.gguf_writer.add_array ("clip.vision.image_mean", [0.5, 0.5, 0.5])
+ self.gguf_writer.add_array ("clip.vision.image_std", [0.5, 0.5, 0.5])
+ self.gguf_writer.add_bool ("clip.use_gelu", True)
+
+ # load tensors
+ for name, data_torch in self.get_tensors(dir_model):
+ # convert any unsupported data types to float32
+ if data_torch.dtype not in (torch.float16, torch.float32):
+ data_torch = data_torch.to(torch.float32)
+ self.add_tensor(name, data_torch)
+
+ def get_tensors(self, dir_model: Path) -> Iterator[tuple[str, Tensor]]:
+ part_names = Gemma3VisionTower.get_model_part_names(dir_model, "model", ".safetensors")
+ tensor_names_from_parts: set[str] = set()
+ for part_name in part_names:
+ logger.info(f"gguf: loading model part '{part_name}'")
+ from safetensors import safe_open
+ ctx = cast(ContextManager[Any], safe_open(dir_model / part_name, framework="pt", device="cpu"))
+ with ctx as model_part:
+ tensor_names_from_parts.update(model_part.keys())
+
+ for name in model_part.keys():
+ data = model_part.get_slice(name)
+ data = LazyTorchTensor.from_safetensors_slice(data)
+ yield name, data
+
+ def add_tensor(self, name: str, data_torch: Tensor):
+ is_1d = len(data_torch.shape) == 1
+ is_embd = ".embeddings." in name
+ old_dtype = data_torch.dtype
+ can_quantize = not is_1d and not is_embd
+ data_qtype = gguf.GGMLQuantizationType.F32
+
+        # this is to support old checkpoints
+ # TODO: remove this when we have the final model
+ name = name.replace("vision_model.vision_model.", "vision_tower.vision_model.")
+ name = name.replace("multimodal_projector.", "multi_modal_projector.")
+
+ # filter only vision tensors
+ if not name.startswith("vision_tower.vision_model.") and not name.startswith("multi_modal_projector."):
+ return
+ # prefix
+ name = name.replace("vision_tower.vision_model.encoder.layers.", "v.blk.")
+ name = name.replace("vision_tower.vision_model.", "v.")
+ # projector and input embd
+ name = name.replace(".embeddings.patch_embedding.", ".patch_embd.")
+ name = name.replace(".embeddings.position_embedding.", ".position_embd.")
+ name = name.replace(
+ "multi_modal_projector.mm_input_projection_weight",
+ "mm.input_projection.weight"
+ )
+ name = name.replace(
+ "multi_modal_projector.mm_soft_emb_norm.weight",
+ "mm.soft_emb_norm.weight"
+ )
+ name = name.replace("post_layernorm.", "post_ln.")
+ # each block
+ name = name.replace(".self_attn.k_proj.", ".attn_k.")
+ name = name.replace(".self_attn.v_proj.", ".attn_v.")
+ name = name.replace(".self_attn.q_proj.", ".attn_q.")
+ name = name.replace(".self_attn.out_proj.", ".attn_out.")
+ name = name.replace(".layer_norm1.", ".ln1.")
+ name = name.replace(".layer_norm2.", ".ln2.")
+ name = name.replace(".mlp.fc1.", ".ffn_down.")
+ name = name.replace(".mlp.fc2.", ".ffn_up.")
+
+ if can_quantize:
+ if self.ftype == gguf.LlamaFileType.ALL_F32:
+ data_qtype = gguf.GGMLQuantizationType.F32
+ elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+ data_qtype = gguf.GGMLQuantizationType.F16
+ elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+ data_qtype = gguf.GGMLQuantizationType.BF16
+ elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
+ data_qtype = gguf.GGMLQuantizationType.Q8_0
+ else:
+ raise ValueError(f"Unsupported file type: {self.ftype}")
+
+        # correct the norm value; only "soft_emb_norm" needs this correction because it is part of the Gemma projector
+        # the other norm values belong to the SigLIP model and are already correct
+ # ref code: Gemma3RMSNorm
+ if "soft_emb_norm.weight" in name:
+ logger.info(f"Correcting norm value for '{name}'")
+ data_torch = data_torch + 1
+
+ data = data_torch.numpy()
+
+ try:
+ data = gguf.quants.quantize(data, data_qtype)
+ except Exception as e:
+ logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
+ data_qtype = gguf.GGMLQuantizationType.F16
+ data = gguf.quants.quantize(data, data_qtype)
+
+ # reverse shape to make it similar to the internal ggml dimension order
+ shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
+ logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
+
+ self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)
+
+ def write(self):
+ self.gguf_writer.write_header_to_file(path=self.fname_out)
+ self.gguf_writer.write_kv_data_to_file()
+ self.gguf_writer.write_tensors_to_file(progress=True)
+ self.gguf_writer.close()
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="Convert Gemma 3 vision tower safetensors to GGUF format",)
+ parser.add_argument(
+ "--outfile", type=Path, default="mmproj.gguf",
+ help="path to write to",
+ )
+ parser.add_argument(
+ "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
+ help="output format",
+ )
+ parser.add_argument(
+ "--bigendian", action="store_true",
+        help="model is executed on a big-endian machine",
+ )
+ parser.add_argument(
+ "model", type=Path,
+ help="directory containing model file",
+ nargs="?",
+ )
+ parser.add_argument(
+ "--verbose", action="store_true",
+ help="increase output verbosity",
+ )
+
+ args = parser.parse_args()
+ if args.model is None:
+ parser.error("the following arguments are required: model")
+ return args
+
+
+def main() -> None:
+ args = parse_args()
+
+ if args.verbose:
+ logging.basicConfig(level=logging.DEBUG)
+ else:
+ logging.basicConfig(level=logging.INFO)
+
+ dir_model = args.model
+
+ if not dir_model.is_dir():
+ logger.error(f'Error: {args.model} is not a directory')
+ sys.exit(1)
+
+ ftype_map: dict[str, gguf.LlamaFileType] = {
+ "f32": gguf.LlamaFileType.ALL_F32,
+ "f16": gguf.LlamaFileType.MOSTLY_F16,
+ "bf16": gguf.LlamaFileType.MOSTLY_BF16,
+ "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
+ }
+
+ logger.info(f"Loading model: {dir_model.name}")
+
+ with torch.inference_mode():
+ gemma3_vision_tower = Gemma3VisionTower(
+ dir_model=dir_model,
+ fname_out=args.outfile,
+ ftype=ftype_map[args.outtype],
+ is_big_endian=args.bigendian,
+ )
+ gemma3_vision_tower.write()
+
+
+if __name__ == '__main__':
+ main()
+
MINICPM3 = auto()
GEMMA = auto()
GEMMA2 = auto()
+ GEMMA3 = auto()
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
MODEL_ARCH.MINICPM3: "minicpm3",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
+ MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
+ MODEL_ARCH.GEMMA3: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_PRE_NORM,
+ MODEL_TENSOR.FFN_POST_NORM,
+ ],
MODEL_ARCH.STARCODER2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
{ LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
+ { LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
+ {
+ LLM_ARCH_GEMMA3,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+ },
+ },
{
LLM_ARCH_STARCODER2,
{
LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
+ LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
#include <algorithm>
#include <cassert>
#include <cstring>
+#include <cmath>
#include <functional>
#include <map>
#include <sstream>
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 26: type = LLM_TYPE_1B; break;
+ case 34: type = LLM_TYPE_4B; break;
+ case 48: type = LLM_TYPE_12B; break;
+ case 62: type = LLM_TYPE_27B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+
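+                    // only the 27B model scales attention by 1/sqrt(n_embd / n_head); the smaller sizes use 1/sqrt(n_embd_head_k)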
+ hparams.f_attention_scale = type == LLM_TYPE_27B
+ ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
+ : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+ } break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+ }
+ } break;
+ case LLM_ARCH_GEMMA3:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
+ LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale);
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
+ case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
return gf;
}
+ struct ggml_cgraph * build_gemma3() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // important: do not scale raw embedding inputs (i.e. encoded image embeddings); only token embeddings are scaled
+ if (ubatch.token) {
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+ cb(inpL, "inp_scaled", -1);
+ }
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        // gemma3 requires a different mask for layers that use sliding window attention (SWA)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
+ struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
+
+ // "5-to-1 interleaved attention"
+ // 5 layers of local attention followed by 1 layer of global attention
+ static const int sliding_window_pattern = 6;
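+        // with this pattern, layers il = 5, 11, 17, ... use global attention; all other layers use SWA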
+
+ for (int il = 0; il < n_layer; ++il) {
+ const bool is_sliding = (il + 1) % sliding_window_pattern;
+ const float freq_base_l = is_sliding ? 10000.0f : freq_base;
+ const float freq_scale_l = is_sliding ? 1.0f : freq_scale;
+ struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens);
+ Qcur = llm_build_norm(ctx0, Qcur, hparams,
+ model.layers[il].attn_q_norm,
+ NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(Qcur, "Qcur_normed", il);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
+ Kcur = llm_build_norm(ctx0, Kcur, hparams,
+ model.layers[il].attn_k_norm,
+ NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(Kcur, "Kcur_normed", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ cb(Kcur, "Kcur", il);
+
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ model.layers[il].wo, NULL,
+ Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
+ }
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].attn_post_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_post_norm", il);
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
+ struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+ cb(sa_out, "sa_out", il);
+
+ cur = llm_build_norm(ctx0, sa_out, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ // feed-forward network
+ {
+ cur = llm_build_ffn(ctx0, lctx, cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].ffn_post_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "ffn_post_norm", -1);
+
+ cur = ggml_add(ctx0, cur, sa_out);
+ cur = lctx.cvec.apply_to(ctx0, cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
struct ggml_cgraph * build_starcoder2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
{
result = llm.build_gemma2();
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ result = llm.build_gemma3();
+ } break;
case LLM_ARCH_STARCODER2:
{
result = llm.build_starcoder2();