]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : Add Gemma 3 support (+ experimental vision capability) (#12343)
authorXuan-Son Nguyen <redacted>
Wed, 12 Mar 2025 08:30:24 +0000 (09:30 +0100)
committerGitHub <redacted>
Wed, 12 Mar 2025 08:30:24 +0000 (09:30 +0100)
* llama : Add Gemma 3 text-only support

* fix python coding style

* fix compile on ubuntu

* python: fix style

* fix ubuntu compile

* fix build on ubuntu (again)

* fix ubuntu build, finally

* clip : Experimental support for Gemma 3 vision (#12344)

* clip : Experimental support for Gemma 3 vision

* fix build

* PRId64

convert_hf_to_gguf.py
examples/llava/CMakeLists.txt
examples/llava/README-gemma3.md [new file with mode: 0644]
examples/llava/clip.cpp
examples/llava/gemma3-cli.cpp [new file with mode: 0644]
examples/llava/gemma3_convert_encoder_to_gguf.py [new file with mode: 0644]
gguf-py/gguf/constants.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/llama.cpp

index 6358a94e9b55f8d21f564c13427a80ef07385aef..b5d95bd5639f3a294a0184de69a05d02f48d69ef 100755 (executable)
@@ -861,6 +861,9 @@ class Model:
                 for token_id, token_data in added_tokens_decoder.items():
                     token_id = int(token_id)
                     token: str = token_data["content"]
+                    if token_id >= vocab_size:
+                        logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
+                        continue
                     if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
                         if tokens[token_id] != token.encode("utf-8"):
                             logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token!r}')
@@ -3322,6 +3325,83 @@ class Gemma2Model(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
+class Gemma3Model(Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA3
+    has_vision: bool = False
+
+    # we need to merge the text_config into the root level of hparams
+    def __init__(self, *args, **kwargs):
+        hparams = Model.load_hparams(kwargs["dir_model"])
+        if "text_config" in hparams:
+            hparams = {**hparams, **hparams["text_config"]}
+            kwargs["hparams"] = hparams
+        super().__init__(*args, **kwargs)
+        if "vision_config" in hparams:
+            logger.info("Has vision encoder, but it will be ignored")
+            self.has_vision = True
+
+    def write(self):
+        super().write()
+        if self.has_vision:
+            logger.info("NOTE: this script only convert the language model to GGUF")
+            logger.info("      for the vision model, please use gemma3_convert_encoder_to_gguf.py")
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+        self.gguf_writer.add_add_space_prefix(False)
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        # some default values are not specified in the hparams
+        self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
+        self.gguf_writer.add_key_length(hparams.get("head_dim", 256))
+        self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
+        # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+        assert hparams.get("attn_logit_softcapping") is None
+        assert hparams.get("final_logit_softcapping") is None
+        self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+        self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
+        if hparams.get("rope_scaling") is not None:
+            assert hparams["rope_scaling"]["rope_type"] == "linear"
+            # important: this rope_scaling is only applied for global layers, and not used by 1B model
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.startswith("language_model."):
+            name = name.replace("language_model.", "")
+        elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
+                or name.startswith("multimodal_projector.") or name.startswith("vision_model."): # this is for old HF model, should be removed later
+            # ignore vision tensors
+            return []
+
+        # remove OOV (out-of-vocabulary) rows in token_embd
+        if "embed_tokens.weight" in name:
+            vocab = self._create_vocab_sentencepiece()
+            tokens = vocab[0]
+            data_torch = data_torch[:len(tokens)]
+
+        # ref code in Gemma3RMSNorm
+        # output = output * (1.0 + self.weight.float())
+        if name.endswith("norm.weight"):
+            data_torch = data_torch + 1
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @Model.register("Starcoder2ForCausalLM")
 class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2
index 319effd199aa48ba1a23e93ece247222d8c11a0f..f275ce1ccd0037c84d7eeb729d3dbd48d2e62f95 100644 (file)
@@ -51,6 +51,13 @@ install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 
+set(TARGET llama-gemma3-cli)
+add_executable(${TARGET} gemma3-cli.cpp)
+set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
 set(TARGET llama-llava-clip-quantize-cli)
 add_executable(${TARGET} clip-quantize-cli.cpp)
 set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli)
diff --git a/examples/llava/README-gemma3.md b/examples/llava/README-gemma3.md
new file mode 100644 (file)
index 0000000..20bf73f
--- /dev/null
@@ -0,0 +1,30 @@
+# Gemma 3 vision
+
+> [!IMPORTANT]
+>
+> This is very experimental, only used for demo purpose.
+
+## How to get mmproj.gguf?
+
+```bash
+cd gemma-3-4b-it
+python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py .
+
+# output file is mmproj.gguf
+```
+
+## How to run it?
+
+What you need:
+- The text model GGUF, can be converted using `convert_hf_to_gguf.py`
+- The mmproj file from step above
+- An image file
+
+```bash
+# build
+cmake -B build
+cmake --build build --target llama-gemma3-cli
+
+# run it
+./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
+```
index 7f892beb6edb1cd1b7274760d8cbcc61943a9d44..a1f050e39a0944e8c18ae7172635e2a2f7630300 100644 (file)
@@ -136,6 +136,8 @@ static std::string format(const char * fmt, ...) {
 #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
 #define TN_MVLM_PROJ_PEG   "mm.model.peg.%d.%s"
 #define TN_IMAGE_NEWLINE   "model.image_newline"
+#define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
+#define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"    // gemma3
 
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
 #define TN_MINICPMV_QUERY "resampler.query"
@@ -162,6 +164,7 @@ enum projector_type {
     PROJECTOR_TYPE_RESAMPLER,
     PROJECTOR_TYPE_GLM_EDGE,
     PROJECTOR_TYPE_MERGER,
+    PROJECTOR_TYPE_GEMMA3,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -172,6 +175,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_RESAMPLER, "resampler"},
     { PROJECTOR_TYPE_GLM_EDGE, "adapter"},
     { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
+    { PROJECTOR_TYPE_GEMMA3, "gemma3"},
 };
 
 
@@ -298,7 +302,7 @@ static projector_type clip_projector_type_from_string(const std::string & name)
             return kv.first;
         }
     }
-    return PROJECTOR_TYPE_UNKNOWN;
+    throw std::runtime_error(format("Unknown projector type: %s", name.c_str()));
 }
 
 #ifdef CLIP_DEBUG_FUNCTIONS
@@ -555,6 +559,10 @@ struct clip_vision_model {
     struct ggml_tensor * mm_model_ln_kv_b;
     struct ggml_tensor * mm_model_ln_post_w;
     struct ggml_tensor * mm_model_ln_post_b;
+
+    // gemma3
+    struct ggml_tensor * mm_input_proj_w;
+    struct ggml_tensor * mm_soft_emb_norm_w;
 };
 
 struct clip_ctx {
@@ -569,7 +577,7 @@ struct clip_ctx {
     struct clip_vision_model vision_model;
     projector_type proj_type = PROJECTOR_TYPE_MLP;
 
-    int32_t max_feature_layer;
+    int32_t max_feature_layer; // unused in newer models like gemma3
     float image_mean[3];
     float image_std[3];
     bool use_gelu = false;
@@ -595,7 +603,7 @@ struct clip_ctx {
 
     ggml_backend_sched_ptr sched;
 
-    struct clip_image_size * load_image_size;
+    struct clip_image_size * load_image_size = nullptr;
 
     clip_ctx(clip_context_params & ctx_params) {
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
@@ -631,7 +639,159 @@ struct clip_ctx {
     }
 };
 
-static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    const int image_size = hparams.image_size;
+    int image_size_width  = image_size;
+    int image_size_height = image_size;
+
+    const int patch_size           = hparams.patch_size;
+    const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int hidden_size          = hparams.hidden_size;
+    const int n_head               = hparams.n_head;
+    const int d_head               = hidden_size / n_head;
+    const int n_layer              = hparams.n_layer;
+    const float eps                = hparams.eps;
+
+    GGML_ASSERT(imgs->size == 1); // batch_size == 1
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    // input raw
+    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+
+    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+    inp = ggml_add(ctx0, inp, model.patch_bias);
+
+    // position embeddings
+    struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
+
+    // loop over layers
+    for (int il = 0; il < n_layer; il++) {
+        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+
+        // layernorm1
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
+        }
+
+        // self-attention
+        {
+
+            struct ggml_tensor * Q =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+
+            struct ggml_tensor * K =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+
+            struct ggml_tensor * V =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head));
+            KQ = ggml_soft_max_inplace(ctx0, KQ);
+
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+        }
+
+        // attention output
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+
+        // re-add the layer input, e.g., residual
+        cur = ggml_add(ctx0, cur, embeddings);
+
+        embeddings = cur; // embeddings = residual, cur = hidden_states
+
+        // layernorm2
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+
+        // siglip uses gelu
+        cur = ggml_gelu(ctx0, cur);
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+
+        // residual 2
+        cur = ggml_add(ctx0, embeddings, cur);
+
+        embeddings = cur;
+    }
+
+    // post-layernorm
+    if (ctx->has_post_norm) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+    }
+
+    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        const int batch_size = 1;
+        const int mm_tokens_per_image = 256; // default value for gemma3
+        const int tokens_per_side = sqrt(mm_tokens_per_image);
+        const int patches_per_image = sqrt(num_patches);
+        const int kernel_size = patches_per_image / tokens_per_side;
+
+        embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
+        embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
+
+        // doing a pool2d to reduce the number of output tokens to 256
+        embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
+        embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
+
+        // apply norm before projection
+        embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+        embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w);
+
+        // apply projection
+        embeddings = ggml_mul_mat(ctx0,
+            ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
+            embeddings);
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
     if (!ctx->has_vision_encoder) {
         LOG_ERR("This gguf file seems to have no vision encoder\n");
         return nullptr;
@@ -1177,7 +1337,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         } else {
             GGML_ABORT("fatel error");
         }
-    } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+    }
+    else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
         embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
 
         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@@ -1199,6 +1360,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     return gf;
 }
 
+static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        return clip_image_build_graph_siglip(ctx, imgs);
+    } else {
+        // TODO: we should have one build_* function per model
+        return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
+    }
+}
+
 // read and create ggml_context containing the tensors and their data
 struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     return clip_init(fname, clip_context_params{
@@ -1358,8 +1528,12 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
         GGML_ASSERT(new_clip->has_vision_encoder);
         GGML_ASSERT(!new_clip->has_text_encoder);
 
-        idx = get_key_idx(ctx, KEY_USE_GELU);
-        new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
+        try {
+            idx = get_key_idx(ctx, KEY_USE_GELU);
+            new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
+        } catch (std::runtime_error & /*e*/) {
+            new_clip->use_gelu = false;
+        }
 
         try {
             idx = get_key_idx(ctx, KEY_USE_SILU);
@@ -1567,11 +1741,17 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
         }
 
         try {
-            vision_model.patch_embeddings_0    = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+            vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+        } catch(const std::exception& /*e*/) {
+            vision_model.patch_embeddings_0 = nullptr;
+        }
+
+        try {
             vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
         } catch(const std::exception& /*e*/) {
-            LOG_ERR("%s: failed to load vision model tensors\n", __func__);
+            vision_model.position_embeddings = nullptr;
         }
+
         try {
             vision_model.patch_embeddings_1    = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
         } catch(const std::exception& /*e*/) {
@@ -1682,6 +1862,10 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
             vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
             vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
         }
+        else if (new_clip->proj_type == PROJECTOR_TYPE_GEMMA3) {
+            vision_model.mm_input_proj_w    = get_tensor(new_clip->ctx_data, TN_MM_INP_PROJ);
+            vision_model.mm_soft_emb_norm_w = get_tensor(new_clip->ctx_data, TN_MM_SOFT_EMB_N);
+        }
         else {
             std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
             throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
@@ -2223,7 +2407,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
         return true;
     }
 
-    if (ctx->has_glm_projector) {
+    if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         res_imgs->size = 1;
         res_imgs->data = new clip_image_f32[res_imgs->size];
         clip_image_u8 resized_image;
@@ -2748,6 +2932,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
             free(positions_data);
         }
+        else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+            // do nothing
+        }
         else {
             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
 
@@ -2960,6 +3147,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
     if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
         return ctx->vision_model.mm_1_b->ne[0];
     }
+    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        return ctx->vision_model.mm_input_proj_w->ne[0];
+    }
 
     std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
     throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp
new file mode 100644 (file)
index 0000000..a07864d
--- /dev/null
@@ -0,0 +1,341 @@
+#include "arg.h"
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+#include "clip.h"
+#include "stb_image.h"
+#include "llama.h"
+#include "ggml.h"
+#include "console.h"
+
+#include <vector>
+#include <limits.h>
+#include <inttypes.h>
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include <signal.h>
+#include <unistd.h>
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include <windows.h>
+#include <signal.h>
+#endif
+
+static bool g_is_generating = false;
+
+/**
+ * Please note that this is NOT a production-ready stuff.
+ * It is a playground for trying Gemma 3 vision capabilities.
+ * For contributors: please keep this code simple and easy to understand.
+ */
+
+static void show_additional_info(int /*argc*/, char ** argv) {
+    LOG(
+        "Experimental CLI for using Gemma 3 vision model\n\n"
+        "Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
+        "  -m and --mmproj are required\n"
+        "  --image and -p are optional, if NOT provided, the CLI will run in chat mode\n",
+        argv[0]
+    );
+}
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
+static void sigint_handler(int signo) {
+    if (signo == SIGINT) {
+        if (g_is_generating) {
+            g_is_generating = false;
+        } else {
+            console::cleanup();
+            LOG("\nInterrupted by user\n");
+            _exit(130);
+        }
+    }
+}
+#endif
+
+struct gemma3_context {
+    struct clip_ctx    * ctx_clip = NULL;
+    common_init_result   llama_init;
+
+    llama_model       * model;
+    llama_context     * lctx;
+    const llama_vocab * vocab;
+    llama_batch         batch;
+
+    int n_threads    = 1;
+    llama_pos n_past = 0;
+
+    gemma3_context(common_params & params) : llama_init(common_init_from_params(params)) {
+        model = llama_init.model.get();
+        lctx = llama_init.context.get();
+        vocab = llama_model_get_vocab(model);
+        n_threads = params.cpuparams.n_threads;
+        batch = llama_batch_init(params.n_batch, 0, 1);
+        init_clip_model(params);
+    }
+
+    void init_clip_model(common_params & params) {
+        const char * clip_path = params.mmproj.c_str();
+        ctx_clip = clip_model_load(clip_path, params.verbosity > 1);
+    }
+
+    ~gemma3_context() {
+        clip_free(ctx_clip);
+    }
+};
+
+struct decode_embd_batch {
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_batch batch;
+    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+        pos     .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_id_0.resize(1);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens       =*/ n_tokens,
+            /*tokens         =*/ nullptr,
+            /*embd           =*/ embd,
+            /*pos            =*/ pos.data(),
+            /*n_seq_id       =*/ n_seq_id.data(),
+            /*seq_id         =*/ seq_ids.data(),
+            /*logits         =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+};
+
+static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
+    llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
+    common_batch_clear(ctx.batch);
+    for (llama_token & t : tokens) {
+        common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
+    }
+    if (logits_last) {
+        ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
+    }
+    // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
+    if (llama_decode(ctx.lctx, ctx.batch)) {
+        LOG_ERR("Failed to decode text\n");
+        return 1;
+    }
+    return 0;
+}
+
+static int eval_image(gemma3_context & ctx, std::string & fname) {
+    std::vector<float> image_embd_v;
+    int n_embd = llama_model_n_embd(ctx.model);
+    int n_tokens = 256;
+    image_embd_v.resize(n_tokens * n_embd);
+
+    bool ok;
+    struct clip_image_u8 * img_u8 = clip_image_u8_init();
+    ok = clip_image_load_from_file(fname.c_str(), img_u8);
+    if (!ok) {
+        LOG_ERR("Unable to load image %s\n", fname.c_str());
+        clip_image_u8_free(img_u8);
+        return 2; // non-fatal error
+    }
+
+    clip_image_f32_batch batch_f32;
+    ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
+    if (!ok) {
+        LOG_ERR("Unable to preprocess image\n");
+        clip_image_f32_batch_free(&batch_f32);
+        clip_image_u8_free(img_u8);
+        return 1;
+    }
+
+    int64_t t0 = ggml_time_ms();
+    LOG("Encoding image %s\n", fname.c_str());
+    ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data());
+    if (!ok) {
+        LOG_ERR("Unable to encode image\n");
+        clip_image_f32_batch_free(&batch_f32);
+        clip_image_u8_free(img_u8);
+        return 1;
+    }
+    LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
+
+    clip_image_f32_batch_free(&batch_f32);
+    clip_image_u8_free(img_u8);
+
+    // decode image embeddings
+    int64_t t1 = ggml_time_ms();
+    eval_text(ctx, "<start_of_image>");
+    llama_set_causal_attn(ctx.lctx, false);
+    decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
+    if (llama_decode(ctx.lctx, batch_img.batch)) {
+        LOG_ERR("failed to decode image\n");
+        return 1;
+    }
+    ctx.n_past += n_tokens;
+    llama_set_causal_attn(ctx.lctx, true);
+    eval_text(ctx, "<end_of_image>");
+    LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
+    return 0;
+}
+
+static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
+    for (int i = 0; i < n_predict; i++) {
+        if (i > n_predict || !g_is_generating) {
+            printf("\n");
+            break;
+        }
+
+        llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
+        common_sampler_accept(smpl, token_id, true);
+
+        if (llama_vocab_is_eog(ctx.vocab, token_id)) {
+            printf("\n");
+            break; // end of generation
+        }
+
+        printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
+        fflush(stdout);
+
+        // eval the token
+        common_batch_clear(ctx.batch);
+        common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
+        if (llama_decode(ctx.lctx, ctx.batch)) {
+            LOG_ERR("failed to decode token\n");
+            return 1;
+        }
+    }
+    return 0;
+}
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    common_params params;
+    params.sampling.temp = 0.2; // lower temp by default for better quality
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
+        return 1;
+    }
+
+    common_init();
+
+    if (params.mmproj.empty()) {
+        show_additional_info(argc, argv);
+        return 1;
+    }
+
+    gemma3_context ctx(params);
+    printf("%s: %s\n", __func__, params.model.c_str());
+
+    bool is_single_turn = !params.prompt.empty() && !params.image.empty();
+
+    struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
+    int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
+
+    // ctrl+C handling
+    {
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+        struct sigaction sigint_action;
+        sigint_action.sa_handler = sigint_handler;
+        sigemptyset (&sigint_action.sa_mask);
+        sigint_action.sa_flags = 0;
+        sigaction(SIGINT, &sigint_action, NULL);
+#elif defined (_WIN32)
+        auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
+            return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
+        };
+        SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
+#endif
+    }
+
+    if (eval_text(ctx, "<bos>")) {
+        return 1;
+    }
+
+    if (is_single_turn) {
+        g_is_generating = true;
+        if (eval_text(ctx, "<start_of_turn>user\n")) {
+            return 1;
+        }
+        for (auto & fname : params.image) {
+            if (eval_image(ctx, fname)) {
+                return 1;
+            }
+        }
+        if (eval_text(ctx, params.prompt + "<end_of_turn><start_of_turn>model\n", true)) {
+            return 1;
+        }
+        if (generate_response(ctx, smpl, n_predict)) {
+            return 1;
+        }
+
+    } else {
+        LOG("\n Running in chat mode, available commands:");
+        LOG("\n   /image <path>    load an image");
+        LOG("\n   /clear           clear the chat history");
+        LOG("\n   /quit or /exit   exit the program");
+        LOG("\n");
+
+        if (eval_text(ctx, "<start_of_turn>user\n")) {
+            return 1;
+        }
+
+        while (true) {
+            g_is_generating = false;
+            LOG("\n> ");
+            console::set_display(console::user_input);
+            std::string line;
+            console::readline(line, false);
+            console::set_display(console::reset);
+            line = string_strip(line);
+            if (line.empty()) {
+                continue;
+            }
+            if (line == "/quit" || line == "/exit") {
+                break;
+            }
+            if (line == "/clear") {
+                ctx.n_past = 0;
+                llama_kv_cache_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
+                LOG("Chat history cleared\n\n");
+                continue;
+            }
+            g_is_generating = true;
+            if (line.find("/image") == 0) {
+                std::string image = line.substr(7);
+                int res = eval_image(ctx, image);
+                if (res == 2) {
+                    continue; // image not found
+                }
+                if (res) {
+                    return 1;
+                }
+                continue;
+            }
+            if (eval_text(ctx, line + "<end_of_turn><start_of_turn>model\n", true)) {
+                return 1;
+            }
+            if (generate_response(ctx, smpl, n_predict)) {
+                return 1;
+            }
+            if (eval_text(ctx, "<end_of_turn><start_of_turn>user\n")) {
+                return 1;
+            }
+        }
+    }
+
+    return 0;
+}
diff --git a/examples/llava/gemma3_convert_encoder_to_gguf.py b/examples/llava/gemma3_convert_encoder_to_gguf.py
new file mode 100644 (file)
index 0000000..241b526
--- /dev/null
@@ -0,0 +1,307 @@
+import gguf
+import argparse
+import logging
+import sys
+import torch
+import json
+import os
+import numpy as np
+from typing import cast, ContextManager, Any, Iterator
+from pathlib import Path
+from torch import Tensor
+
+logger = logging.getLogger("gemma3-mmproj")
+
+
+# (copied from convert_hf_to_gguf.py)
+# tree of lazy tensors
+class LazyTorchTensor(gguf.LazyBase):
+    _tensor_type = torch.Tensor
+    # to keep the type-checker happy
+    dtype: torch.dtype
+    shape: torch.Size
+
+    # only used when converting a torch.Tensor to a np.ndarray
+    _dtype_map: dict[torch.dtype, type] = {
+        torch.float16: np.float16,
+        torch.float32: np.float32,
+    }
+
+    # used for safetensors slices
+    # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
+    # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
+    _dtype_str_map: dict[str, torch.dtype] = {
+        "F64": torch.float64,
+        "F32": torch.float32,
+        "BF16": torch.bfloat16,
+        "F16": torch.float16,
+        # "U64": torch.uint64,
+        "I64": torch.int64,
+        # "U32": torch.uint32,
+        "I32": torch.int32,
+        # "U16": torch.uint16,
+        "I16": torch.int16,
+        "U8": torch.uint8,
+        "I8": torch.int8,
+        "BOOL": torch.bool,
+        "F8_E4M3": torch.float8_e4m3fn,
+        "F8_E5M2": torch.float8_e5m2,
+    }
+
+    def numpy(self) -> gguf.LazyNumpyTensor:
+        dtype = self._dtype_map[self.dtype]
+        return gguf.LazyNumpyTensor(
+            meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
+            args=(self,),
+            func=(lambda s: s.numpy())
+        )
+
+    @classmethod
+    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
+        return torch.empty(size=shape, dtype=dtype, device="meta")
+
+    @classmethod
+    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
+        dtype = cls._dtype_str_map[st_slice.get_dtype()]
+        shape: tuple[int, ...] = tuple(st_slice.get_shape())
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
+        return cast(torch.Tensor, lazy)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        del types  # unused
+
+        if kwargs is None:
+            kwargs = {}
+
+        if func is torch.Tensor.numpy:
+            return args[0].numpy()
+
+        return cls._wrap_fn(func)(*args, **kwargs)
+
+
+class Gemma3VisionTower:
+    hparams: dict
+    gguf_writer: gguf.GGUFWriter
+    fname_out: Path
+    ftype: gguf.LlamaFileType
+
+    @staticmethod
+    def load_hparams(dir_model: Path):
+        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+            return json.load(f)
+
+    @staticmethod
+    def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
+        part_names: list[str] = []
+        for filename in os.listdir(dir_model):
+            if filename.startswith(prefix) and filename.endswith(suffix):
+                part_names.append(filename)
+        part_names.sort()
+        return part_names
+
+    def __init__(self,
+                 dir_model: Path,
+                 fname_out: Path,
+                 ftype: gguf.LlamaFileType,
+                 is_big_endian: bool,):
+        hparams = Gemma3VisionTower.load_hparams(dir_model)
+        self.hparams = hparams
+        self.fname_out = fname_out
+        self.ftype = ftype
+        endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
+        self.gguf_writer = gguf.GGUFWriter(path=None, arch="clip", endianess=endianess)
+
+        text_config = hparams["text_config"]
+        vision_config = hparams["vision_config"]
+
+        assert hparams["architectures"][0] == "Gemma3ForConditionalGeneration"
+        assert text_config is not None
+        assert vision_config is not None
+
+        self.gguf_writer.add_string ("clip.projector_type",              "gemma3")
+        self.gguf_writer.add_bool   ("clip.has_text_encoder",            False)
+        self.gguf_writer.add_bool   ("clip.has_vision_encoder",          True)
+        self.gguf_writer.add_bool   ("clip.has_llava_projector",         False) # legacy
+        self.gguf_writer.add_uint32 ("clip.vision.image_size",           vision_config["image_size"])
+        self.gguf_writer.add_uint32 ("clip.vision.patch_size",           vision_config["patch_size"])
+        self.gguf_writer.add_uint32 ("clip.vision.embedding_length",     vision_config["hidden_size"])
+        self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length",  vision_config["intermediate_size"])
+        self.gguf_writer.add_uint32 ("clip.vision.projection_dim",       text_config["hidden_size"])
+        self.gguf_writer.add_uint32 ("clip.vision.block_count",          vision_config["num_hidden_layers"])
+        self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", vision_config["num_attention_heads"])
+        self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", vision_config.get("layer_norm_eps", 1e-6))
+        # default values taken from HF tranformers code
+        self.gguf_writer.add_array  ("clip.vision.image_mean", [0.5, 0.5, 0.5])
+        self.gguf_writer.add_array  ("clip.vision.image_std",  [0.5, 0.5, 0.5])
+        self.gguf_writer.add_bool   ("clip.use_gelu", True)
+
+        # load tensors
+        for name, data_torch in self.get_tensors(dir_model):
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+            self.add_tensor(name, data_torch)
+
+    def get_tensors(self, dir_model: Path) -> Iterator[tuple[str, Tensor]]:
+        part_names = Gemma3VisionTower.get_model_part_names(dir_model, "model", ".safetensors")
+        tensor_names_from_parts: set[str] = set()
+        for part_name in part_names:
+            logger.info(f"gguf: loading model part '{part_name}'")
+            from safetensors import safe_open
+            ctx = cast(ContextManager[Any], safe_open(dir_model / part_name, framework="pt", device="cpu"))
+            with ctx as model_part:
+                tensor_names_from_parts.update(model_part.keys())
+
+                for name in model_part.keys():
+                    data = model_part.get_slice(name)
+                    data = LazyTorchTensor.from_safetensors_slice(data)
+                    yield name, data
+
+    def add_tensor(self, name: str, data_torch: Tensor):
+        is_1d = len(data_torch.shape) == 1
+        is_embd = ".embeddings." in name
+        old_dtype = data_torch.dtype
+        can_quantize = not is_1d and not is_embd
+        data_qtype = gguf.GGMLQuantizationType.F32
+
+        # this is to support old checkpoint
+        # TODO: remove this when we have the final model
+        name = name.replace("vision_model.vision_model.", "vision_tower.vision_model.")
+        name = name.replace("multimodal_projector.", "multi_modal_projector.")
+
+        # filter only vision tensors
+        if not name.startswith("vision_tower.vision_model.") and not name.startswith("multi_modal_projector."):
+            return
+        # prefix
+        name = name.replace("vision_tower.vision_model.encoder.layers.", "v.blk.")
+        name = name.replace("vision_tower.vision_model.", "v.")
+        # projector and input embd
+        name = name.replace(".embeddings.patch_embedding.", ".patch_embd.")
+        name = name.replace(".embeddings.position_embedding.", ".position_embd.")
+        name = name.replace(
+            "multi_modal_projector.mm_input_projection_weight",
+            "mm.input_projection.weight"
+        )
+        name = name.replace(
+            "multi_modal_projector.mm_soft_emb_norm.weight",
+            "mm.soft_emb_norm.weight"
+        )
+        name = name.replace("post_layernorm.", "post_ln.")
+        # each block
+        name = name.replace(".self_attn.k_proj.", ".attn_k.")
+        name = name.replace(".self_attn.v_proj.", ".attn_v.")
+        name = name.replace(".self_attn.q_proj.", ".attn_q.")
+        name = name.replace(".self_attn.out_proj.", ".attn_out.")
+        name = name.replace(".layer_norm1.", ".ln1.")
+        name = name.replace(".layer_norm2.", ".ln2.")
+        name = name.replace(".mlp.fc1.", ".ffn_down.")
+        name = name.replace(".mlp.fc2.", ".ffn_up.")
+
+        if can_quantize:
+            if self.ftype == gguf.LlamaFileType.ALL_F32:
+                data_qtype = gguf.GGMLQuantizationType.F32
+            elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+                data_qtype = gguf.GGMLQuantizationType.F16
+            elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                data_qtype = gguf.GGMLQuantizationType.BF16
+            elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
+                data_qtype = gguf.GGMLQuantizationType.Q8_0
+            else:
+                raise ValueError(f"Unsupported file type: {self.ftype}")
+
+        # corrent norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector
+        # the other norm values are part of SigLIP model, and they are already correct
+        # ref code: Gemma3RMSNorm
+        if "soft_emb_norm.weight" in name:
+            logger.info(f"Correcting norm value for '{name}'")
+            data_torch = data_torch + 1
+
+        data = data_torch.numpy()
+
+        try:
+            data = gguf.quants.quantize(data, data_qtype)
+        except Exception as e:
+            logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
+            data_qtype = gguf.GGMLQuantizationType.F16
+            data = gguf.quants.quantize(data, data_qtype)
+
+        # reverse shape to make it similar to the internal ggml dimension order
+        shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
+        logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
+
+        self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)
+
+    def write(self):
+        self.gguf_writer.write_header_to_file(path=self.fname_out)
+        self.gguf_writer.write_kv_data_to_file()
+        self.gguf_writer.write_tensors_to_file(progress=True)
+        self.gguf_writer.close()
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Convert Gemma 3 vision tower safetensors to GGUF format",)
+    parser.add_argument(
+        "--outfile", type=Path, default="mmproj.gguf",
+        help="path to write to",
+    )
+    parser.add_argument(
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
+        help="output format",
+    )
+    parser.add_argument(
+        "--bigendian", action="store_true",
+        help="model is executed on big endian machine",
+    )
+    parser.add_argument(
+        "model", type=Path,
+        help="directory containing model file",
+        nargs="?",
+    )
+    parser.add_argument(
+        "--verbose", action="store_true",
+        help="increase output verbosity",
+    )
+
+    args = parser.parse_args()
+    if args.model is None:
+        parser.error("the following arguments are required: model")
+    return args
+
+
+def main() -> None:
+    args = parse_args()
+
+    if args.verbose:
+        logging.basicConfig(level=logging.DEBUG)
+    else:
+        logging.basicConfig(level=logging.INFO)
+
+    dir_model = args.model
+
+    if not dir_model.is_dir():
+        logger.error(f'Error: {args.model} is not a directory')
+        sys.exit(1)
+
+    ftype_map: dict[str, gguf.LlamaFileType] = {
+        "f32": gguf.LlamaFileType.ALL_F32,
+        "f16": gguf.LlamaFileType.MOSTLY_F16,
+        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
+        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
+    }
+
+    logger.info(f"Loading model: {dir_model.name}")
+
+    with torch.inference_mode():
+        gemma3_vision_tower = Gemma3VisionTower(
+            dir_model=dir_model,
+            fname_out=args.outfile,
+            ftype=ftype_map[args.outtype],
+            is_big_endian=args.bigendian,
+        )
+        gemma3_vision_tower.write()
+
+
+if __name__ == '__main__':
+    main()
+
index ecac5b4bb7f59cfc91f0d6b1b2956521198d3e3f..19624eae04ece3db0c2d8d3438025d1416116ae6 100644 (file)
@@ -253,6 +253,7 @@ class MODEL_ARCH(IntEnum):
     MINICPM3         = auto()
     GEMMA            = auto()
     GEMMA2           = auto()
+    GEMMA3           = auto()
     STARCODER2       = auto()
     RWKV6            = auto()
     RWKV6QWEN2       = auto()
@@ -440,6 +441,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.MINICPM3:         "minicpm3",
     MODEL_ARCH.GEMMA:            "gemma",
     MODEL_ARCH.GEMMA2:           "gemma2",
+    MODEL_ARCH.GEMMA3:           "gemma3",
     MODEL_ARCH.STARCODER2:       "starcoder2",
     MODEL_ARCH.RWKV6:            "rwkv6",
     MODEL_ARCH.RWKV6QWEN2:       "rwkv6qwen2",
@@ -1077,6 +1079,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_PRE_NORM,
         MODEL_TENSOR.FFN_POST_NORM,
     ],
+    MODEL_ARCH.GEMMA3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_PRE_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+    ],
     MODEL_ARCH.STARCODER2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 97a1e7e5e01ef67c75bd63d7ddad39ade99e7ec5..28f2bbc8f72bffc39562b260a5d5b9458715b31f 100644 (file)
@@ -36,6 +36,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MINICPM3,         "minicpm3"         },
     { LLM_ARCH_GEMMA,            "gemma"            },
     { LLM_ARCH_GEMMA2,           "gemma2"           },
+    { LLM_ARCH_GEMMA3,           "gemma3"           },
     { LLM_ARCH_STARCODER2,       "starcoder2"       },
     { LLM_ARCH_MAMBA,            "mamba"            },
     { LLM_ARCH_XVERSE,           "xverse"           },
@@ -766,6 +767,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
index 122fdcebe0af6a5ce25612cb56888b53bb007c38..2ec2e2362eba1a3b0892e905235dfefb63bdbab0 100644 (file)
@@ -40,6 +40,7 @@ enum llm_arch {
     LLM_ARCH_MINICPM3,
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
+    LLM_ARCH_GEMMA3,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
index 1da4eae7e63e2aa435cfd52a0b3a6e9772218ccf..9f75589d805a9f3bb2793d9e1c2fdd4ed4680eea 100644 (file)
@@ -9,6 +9,7 @@
 #include <algorithm>
 #include <cassert>
 #include <cstring>
+#include <cmath>
 #include <functional>
 #include <map>
 #include <sstream>
@@ -864,6 +865,23 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                }
             } break;
+        case LLM_ARCH_GEMMA3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 26: type = LLM_TYPE_1B; break;
+                    case 34: type = LLM_TYPE_4B; break;
+                    case 48: type = LLM_TYPE_12B; break;
+                    case 62: type = LLM_TYPE_27B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+
+                hparams.f_attention_scale = type == LLM_TYPE_27B
+                    ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
+                    : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2454,6 +2472,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
                         layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
 
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_GEMMA3:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_k_norm    = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM,    "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_q_norm    = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM,    "weight", i), {n_embd_head_k}, 0);
+
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
                         layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
@@ -3650,6 +3697,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
         LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
         LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
+        LLAMA_LOG_INFO("%s: f_attn_scale     = %.1e\n",   __func__, hparams.f_attention_scale);
         LLAMA_LOG_INFO("%s: n_ff             = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
         LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
         LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
@@ -3923,6 +3971,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
         case LLM_ARCH_PHIMOE:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_GEMMA3:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
index 607f27861596931ea50ecde86251f83040a539d2..4a4e91490107c1f9384ca601d81a092443c60610 100644 (file)
@@ -4978,6 +4978,149 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_gemma3() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
+        if (ubatch.token) {
+            inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+            cb(inpL, "inp_scaled", -1);
+        }
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        // gemma3 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask(true);
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
+
+        // "5-to-1 interleaved attention"
+        // 5 layers of local attention followed by 1 layer of global attention
+        static const int sliding_window_pattern = 6;
+
+        for (int il = 0; il < n_layer; ++il) {
+            const bool is_sliding          = (il + 1) % sliding_window_pattern;
+            const float freq_base_l        = is_sliding ? 10000.0f    : freq_base;
+            const float freq_scale_l       = is_sliding ? 1.0f        : freq_scale;
+            struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens);
+                Qcur = llm_build_norm(ctx0, Qcur, hparams,
+                    model.layers[il].attn_q_norm,
+                    NULL,
+                    LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
+                Kcur = llm_build_norm(ctx0, Kcur, hparams,
+                    model.layers[il].attn_k_norm,
+                    NULL,
+                    LLM_NORM_RMS, cb, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
+            }
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = llm_build_norm(ctx0, sa_out, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                model.layers[il].ffn_post_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, sa_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 
     struct ggml_cgraph * build_starcoder2() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
@@ -8298,6 +8441,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_gemma2();
             } break;
+        case LLM_ARCH_GEMMA3:
+            {
+                result = llm.build_gemma3();
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 result = llm.build_starcoder2();