]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd : support Qwen 2.5 Omni (input audio+vision, no audio output) (#13784)
authorXuan-Son Nguyen <redacted>
Tue, 27 May 2025 12:06:10 +0000 (14:06 +0200)
committerGitHub <redacted>
Tue, 27 May 2025 12:06:10 +0000 (14:06 +0200)
* mtmd : allow multiple modalities at the same time

* refactor mtmd tokenizer

* fix compile

* ok, missing SinusoidsPositionEmbedding

* first working version

* fix style

* more strict validate of n_embd

* refactor if..else to switch

* fix regression

* add test for 3B

* update docs

* fix tokenizing with add_special

* add more tests

* fix test case "huge"

* rm redundant code

* set_position_mrope_1d rm n_tokens

12 files changed:
convert_hf_to_gguf.py
docs/multimodal.md
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
tools/mtmd/clip-impl.h
tools/mtmd/clip.cpp
tools/mtmd/clip.h
tools/mtmd/mtmd-cli.cpp
tools/mtmd/mtmd-helper.cpp
tools/mtmd/mtmd.cpp
tools/mtmd/test-2.mp3 [new file with mode: 0644]
tools/mtmd/tests.sh

index 91af508a2fb28df75aee78354e70d91a28be8dae..a015ecee08328bbe56acff98462a2efbb6ac9cce 100755 (executable)
@@ -432,6 +432,9 @@ class ModelBase:
                 if "llm_config" in config:
                     # rename for InternVL
                     config["text_config"] = config["llm_config"]
+                if "thinker_config" in config:
+                    # rename for Qwen2.5-Omni
+                    config["text_config"] = config["thinker_config"]["text_config"]
                 return config
 
     @classmethod
@@ -1121,18 +1124,21 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
+
     has_vision_encoder: bool = True # by default
     has_audio_encoder: bool = False
 
+    # for models having multiple encoders, we need to separate their hparams
+    hparams_vision: dict[str, Any] | None = None
+    hparams_audio: dict[str, Any] | None = None
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         if self.model_arch != gguf.MODEL_ARCH.MMPROJ:
             raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ")
 
-        if self.has_vision_encoder and self.has_audio_encoder:
-            raise NotImplementedError("both vision + audio not supported yet")
-
         # get n_embd of the text model
         if "text_config" not in self.hparams:
             self.hparams["text_config"] = {}
@@ -1143,22 +1149,32 @@ class MmprojModel(ModelBase):
         assert self.n_embd_text > 0, "n_embd not found in hparams"
 
         # move vision config to the top level, while preserving the original hparams in global_config
-        self.global_config = self.hparams
+        import copy
+        self.global_config = copy.deepcopy(self.hparams)
+        self.hparams_vision = self.get_vision_config()
+        self.hparams_audio = self.get_audio_config()
 
-        if "vision_config" in self.hparams:
-            self.hparams = self.hparams["vision_config"]
-        elif "audio_config" in self.hparams:
-            self.hparams = self.hparams["audio_config"]
-        else:
+        if self.hparams_vision is None and self.hparams_audio is None:
             raise ValueError("vision_config / audio_config not found in hparams")
 
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
+        # for compat with vision-only models
+        self.hparams = self.hparams_vision or self.hparams_audio or self.hparams
+
+        # TODO @ngxson : this is a hack to support both vision and audio encoders
+        have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder
+        self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True)
         self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
 
         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)
 
+    def get_vision_config(self) -> dict[str, Any] | None:
+        return self.global_config.get("vision_config")
+
+    def get_audio_config(self) -> dict[str, Any] | None:
+        return self.global_config.get("audio_config")
+
     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
 
@@ -1170,26 +1186,26 @@ class MmprojModel(ModelBase):
             self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
 
             # vision config
-            self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
-            self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
-            self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
-            self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
-            self.gguf_writer.add_vision_block_count(self.block_count)
-            self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
+            self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"]))
+            self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
+            self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
+            self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
+            self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
+            self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
 
             # preprocessor config
             self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
             self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
 
-        elif self.has_audio_encoder:
+        if self.has_audio_encoder:
             self.gguf_writer.add_clip_has_audio_encoder(True)
             self.gguf_writer.add_audio_projection_dim(self.n_embd_text)
 
             # audio config
-            self.gguf_writer.add_audio_embedding_length(self.find_hparam(["hidden_size"]))
-            self.gguf_writer.add_audio_feed_forward_length(self.find_hparam(["intermediate_size"]))
-            self.gguf_writer.add_audio_block_count(self.block_count)
-            self.gguf_writer.add_audio_head_count(self.find_hparam(["num_attention_heads"]))
+            self.gguf_writer.add_audio_embedding_length(self.find_aparam(["hidden_size"]))
+            self.gguf_writer.add_audio_feed_forward_length(self.find_aparam(["intermediate_size"]))
+            self.gguf_writer.add_audio_block_count(self.find_aparam(self.n_block_keys))
+            self.gguf_writer.add_audio_head_count(self.find_aparam(["num_attention_heads"]))
 
         else:
             raise ValueError("MmprojModel must have either vision or audio encoder")
@@ -1197,6 +1213,22 @@ class MmprojModel(ModelBase):
     def write_vocab(self):
         raise ValueError("MmprojModel does not support vocab writing")
 
+    def find_vparam(self, keys: Iterable[str], optional: bool = False) -> Any:
+        assert self.hparams_vision is not None
+        return self._find_param(self.hparams_vision, keys, optional)
+
+    def find_aparam(self, keys: Iterable[str], optional: bool = False) -> Any:
+        assert self.hparams_audio is not None
+        return self._find_param(self.hparams_audio, keys, optional)
+
+    def _find_param(self, obj: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any:
+        key = next((k for k in keys if k in obj), None)
+        if key is not None:
+            return obj[key]
+        if optional:
+            return None
+        raise KeyError(f"could not find any of: {keys}")
+
 
 @ModelBase.register("GPTNeoXForCausalLM")
 class GPTNeoXModel(TextModel):
@@ -2674,7 +2706,12 @@ class Qwen2Model(TextModel):
         yield from super().modify_tensors(data_torch, name, bid)
 
 
-@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+@ModelBase.register(
+    "Qwen2VLModel",
+    "Qwen2VLForConditionalGeneration",
+    "Qwen2_5_VLForConditionalGeneration",
+    "Qwen2_5OmniModel",
+)
 class Qwen2VLModel(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2VL
 
@@ -2692,8 +2729,11 @@ class Qwen2VLModel(TextModel):
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
-        if name.startswith("visual."):
-            # skip visual tensors
+        if name.startswith("thinker."):
+            name = name.replace("thinker.", "")
+        if name.startswith("visual") or name.startswith("audio") or \
+                name.startswith("talker") or name.startswith("token2wav"):
+            # skip multimodal tensors
             return []
         return [(self.map_tensor_name(name), data_torch)]
 
@@ -2702,21 +2742,27 @@ class Qwen2VLModel(TextModel):
 class Qwen2VLVisionModel(MmprojModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.hparams["image_size"] = self.hparams.get("image_size", 560)
+        assert self.hparams_vision is not None
+        self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560)
         # rename config.json values
-        self.hparams["num_attention_heads"] = self.hparams.get("num_heads")
-        self.hparams["num_hidden_layers"] = self.hparams.get("depth")
-        if "embed_dim" in self.hparams: # qwen2vl
-            self.hparams["intermediate_size"] = self.hparams.get("hidden_size")
-            self.hparams["hidden_size"] = self.hparams.get("embed_dim")
+        self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
+        self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
+        if "embed_dim" in self.hparams_vision: # qwen2vl
+            self.hparams_vision["intermediate_size"] = self.hparams_vision.get("hidden_size")
+            self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim")
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        hparams = self.hparams
-        if self.global_config['model_type'] == 'qwen2_vl':
+        assert self.hparams_vision is not None
+        hparams = self.hparams_vision
+        model_type = self.global_config['model_type']
+        if model_type == 'qwen2_vl':
             self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL)
-        elif self.global_config['model_type'] == 'qwen2_5_vl':
-            self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
+        elif model_type == 'qwen2_5_vl' or model_type == 'qwen2_5_omni':
+            if model_type == 'qwen2_5_omni':
+                self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O)
+            else:
+                self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
             self.gguf_writer.add_vision_use_silu(True)
             # find n_wa_pattern (window attention pattern)
             fullatt_block_indexes = hparams.get("fullatt_block_indexes")
@@ -2774,6 +2820,66 @@ class Qwen2VLVisionModel(MmprojModel):
         return [] # skip other tensors
 
 
+@ModelBase.register("Qwen2_5OmniModel")
+class Qwen25OmniModel(Qwen2VLVisionModel):
+    has_vision_encoder = True
+    has_audio_encoder = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_audio is not None
+        self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"]
+        self.hparams_audio["intermediate_size"] = self.hparams_audio["encoder_ffn_dim"]
+        self.hparams_audio["num_attention_heads"] = self.hparams_audio["encoder_attention_heads"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        assert self.hparams_audio is not None
+        self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"])
+        self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5))
+
+    def get_vision_config(self) -> dict[str, Any] | None:
+        return self.global_config["thinker_config"].get("vision_config")
+
+    def get_audio_config(self) -> dict[str, Any] | None:
+        return self.global_config["thinker_config"].get("audio_config")
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        # SinusoidsPositionEmbedding
+        assert self.hparams_audio is not None
+        max_timescale = 10000
+        length = 1500
+        channels = self.hparams_audio["hidden_size"]
+        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
+        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+        pos_embd = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).to(dtype=torch.float32)
+        yield ("audio_tower.embed_positions.weight", pos_embd)
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, new_name, n_dims  # unused
+        if ".conv" in name and ".weight" in name:
+            return gguf.GGMLQuantizationType.F16
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("thinker."):
+            name = name.replace("thinker.", "")
+
+        if name.startswith("audio_tower"):
+            # process audio tensors
+            if "conv1.bias" in name or "conv2.bias" in name:
+                # transpose conv1 and conv2 bias
+                data_torch = data_torch.unsqueeze(-1)
+            if "audio_bos_eos_token" in name:
+                # this tensor is left unused in transformers code
+                # https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809
+                return []
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("InternVisionModel")
 class InternVisionModel(MmprojModel):
     def set_gguf_parameters(self):
index 3a0994a279ae87137698bb6251f12956368bef46..e849c2a0b8ba15bb3fc5e553baef1e39e0b8dbb5 100644 (file)
@@ -98,3 +98,12 @@ NOTE: some models may require large context window, for example: `-c 8192`
 # note: no pre-quantized GGUF this model, as they have very poor result
 # ref: https://github.com/ggml-org/llama.cpp/pull/13760
 ```
+
+**Mixed modalities**:
+
+```sh
+# Qwen2.5 Omni
+# Capabilities: audio input, vision input
+(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF
+(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF
+```
index c6255d6867a1505a147722e240a1e76884d264ac..31163effad8f283cf92915d019c3f68783e398a8 100644 (file)
@@ -2260,6 +2260,7 @@ class VisionProjectorType:
     ULTRAVOX = "ultravox"
     INTERNVL = "internvl"
     QWEN2A = "qwen2a" # audio
+    QWEN25O = "qwen2.5o" # omni
 
 
 # Items here are (block size, type size)
index 4a0615b656812304645cb940135f15da6949dad1..000ffd00615b580d73419433e437fcade4863d67 100644 (file)
@@ -1125,6 +1125,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.A_POST_NORM: (
             "audio_tower.layer_norm", # ultravox
+            "audio_tower.ln_post", # qwen2omni
         ),
 
         MODEL_TENSOR.A_ENC_ATTN_Q: (
@@ -1161,12 +1162,16 @@ class TensorNameMap:
             "audio_tower.layers.{bid}.fc2", # ultravox
         ),
 
+        # note: some tensors below has "audio." pseudo-prefix, to prevent conflicts with vision tensors
+        # this prefix is added in the conversion code in modify_tensors()
+
         MODEL_TENSOR.A_MMPROJ: (
             "audio.multi_modal_projector.linear_{bid}", # ultravox
         ),
 
         MODEL_TENSOR.A_MMPROJ_FC: (
             "audio.multi_modal_projector.linear", # qwen2audio
+            "audio_tower.proj", # qwen2omni
         ),
 
         MODEL_TENSOR.A_MM_NORM_PRE: (
index 27ce8c43f678ccdea992e4292712792f6c2ef3f3..62c936ed00f7752c35dddbcdf3c0771fb6eafd08 100644 (file)
@@ -130,6 +130,7 @@ enum projector_type {
     PROJECTOR_TYPE_INTERNVL,
     PROJECTOR_TYPE_LLAMA4,
     PROJECTOR_TYPE_QWEN2A,
+    PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -148,6 +149,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_INTERNVL,  "internvl"},
     { PROJECTOR_TYPE_LLAMA4,    "llama4"},
     { PROJECTOR_TYPE_QWEN2A,    "qwen2a"},
+    { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
index 6205dad5ae262b5435fc0df569b1574b9b4a6b77..6ae2c2ce46fd28ef7e55a8b43c04812439ee4e5d 100644 (file)
@@ -166,9 +166,6 @@ enum patch_merge_type {
 };
 
 struct clip_hparams {
-    bool has_vision = false;
-    bool has_audio = false;
-
     int32_t image_size;
     int32_t patch_size;
     int32_t n_embd;
@@ -178,9 +175,13 @@ struct clip_hparams {
     int32_t n_layer;
     int32_t proj_scale_factor = 0; // idefics3
 
+    float image_mean[3];
+    float image_std[3];
+
     // for models using dynamic image size, we need to have a smaller image size to warmup
     // otherwise, user will get OOM everytime they load the model
     int32_t warmup_image_size = 0;
+    int32_t warmup_audio_size = 3000;
 
     ffn_op_type ffn_op = FFN_GELU;
 
@@ -199,6 +200,10 @@ struct clip_hparams {
     // audio
     int32_t n_mel_bins = 0; // whisper preprocessor
     int32_t proj_stack_factor = 0; // ultravox
+
+    // legacy
+    bool has_llava_projector = false;
+    int minicpmv_version = 0;
 };
 
 struct clip_layer {
@@ -236,8 +241,10 @@ struct clip_layer {
     ggml_tensor * ls_2_w = nullptr;
 };
 
-struct clip_vision_model {
-    struct clip_hparams hparams;
+struct clip_model {
+    clip_modality modality = CLIP_MODALITY_VISION;
+    projector_type proj_type = PROJECTOR_TYPE_MLP;
+    clip_hparams hparams;
 
     // embeddings
     ggml_tensor * class_embedding = nullptr;
@@ -353,14 +360,7 @@ struct clip_vision_model {
 };
 
 struct clip_ctx {
-    bool has_llava_projector = false;
-    int minicpmv_version = 0;
-
-    struct clip_vision_model vision_model;
-    projector_type proj_type = PROJECTOR_TYPE_MLP;
-
-    float image_mean[3];
-    float image_std[3];
+    clip_model model;
 
     gguf_context_ptr ctx_gguf;
     ggml_context_ptr ctx_data;
@@ -414,11 +414,16 @@ struct clip_ctx {
             ggml_backend_free(backend_cpu);
         }
     }
+
+    // this function is added so that we don't change too much of the existing code
+    projector_type proj_type() const {
+        return model.proj_type;
+    }
 };
 
 struct clip_graph {
     clip_ctx * ctx;
-    const clip_vision_model & model;
+    const clip_model & model;
     const clip_hparams & hparams;
 
     // we only support single image per batch
@@ -441,7 +446,7 @@ struct clip_graph {
 
     clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
             ctx(ctx),
-            model(ctx->vision_model),
+            model(ctx->model),
             hparams(model.hparams),
             img(img),
             patch_size(hparams.patch_size),
@@ -473,7 +478,7 @@ struct clip_graph {
                                 model.position_embeddings,
                                 nullptr);
 
-        if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) {
             const int batch_size = 1;
             GGML_ASSERT(n_patches_x == n_patches_y);
             const int patches_per_image = n_patches_x;
@@ -496,7 +501,7 @@ struct clip_graph {
                 ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
                 cur);
 
-        } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+        } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
             // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
 
             const int scale_factor = model.hparams.proj_scale_factor;
@@ -630,7 +635,7 @@ struct clip_graph {
         const int n_pos            = n_patches;
         const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
 
-        norm_type norm_t = ctx->proj_type == PROJECTOR_TYPE_QWEN25VL
+        norm_type norm_t = ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
             ? NORM_TYPE_RMS // qwen 2.5 vl
             : NORM_TYPE_NORMAL; // qwen 2 vl
 
@@ -846,11 +851,11 @@ struct clip_graph {
             const int d_head = 128;
             int n_head = n_embd/d_head;
             int num_query = 96;
-            if (ctx->minicpmv_version == 2) {
+            if (ctx->model.hparams.minicpmv_version == 2) {
                 num_query = 96;
-            } else if (ctx->minicpmv_version == 3) {
+            } else if (ctx->model.hparams.minicpmv_version == 3) {
                 num_query = 64;
-            } else if (ctx->minicpmv_version == 4) {
+            } else if (ctx->model.hparams.minicpmv_version == 4) {
                 num_query = 64;
             }
 
@@ -1067,7 +1072,7 @@ struct clip_graph {
             int il_last = hparams.n_layer - 1;
             int deepest_feature_layer = -1;
 
-            if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
+            if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV || ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) {
                 il_last += 1;
             }
 
@@ -1201,7 +1206,7 @@ struct clip_graph {
         }
 
         // llava projector (also used by granite)
-        if (ctx->has_llava_projector) {
+        if (ctx->model.hparams.has_llava_projector) {
             embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
 
             ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
@@ -1215,7 +1220,7 @@ struct clip_graph {
             // print_tensor_info(embeddings, "embeddings");
 
             // llava projector
-            if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
+            if (ctx->proj_type() == PROJECTOR_TYPE_MLP) {
                 embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
                 embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
 
@@ -1225,7 +1230,7 @@ struct clip_graph {
                     embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
                 }
             }
-            else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+            else if (ctx->proj_type() == PROJECTOR_TYPE_MLP_NORM) {
                 embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
                 embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
                 // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
@@ -1246,7 +1251,7 @@ struct clip_graph {
                 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
                                     model.mm_4_b);
             }
-            else if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
+            else if (ctx->proj_type() == PROJECTOR_TYPE_LDP) {
                 // MobileVLM projector
                 int n_patch = 24;
                 ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
@@ -1356,7 +1361,7 @@ struct clip_graph {
                 }
                 embeddings = block_1;
             }
-            else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2)
+            else if (ctx->proj_type() == PROJECTOR_TYPE_LDPV2)
             {
                 int n_patch = 24;
                 ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
@@ -1386,7 +1391,7 @@ struct clip_graph {
         }
 
         // glm projector
-        else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
+        else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) {
             size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
             embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
             embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
@@ -1473,7 +1478,7 @@ struct clip_graph {
 
         cb(cur, "after_transformer", -1);
 
-        if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
+        if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
             // StackAudioFrames
             // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
             {
@@ -1518,7 +1523,7 @@ struct clip_graph {
                 cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
             }
 
-        } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
+        } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
             // projector
             cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
             cur = ggml_add(ctx0, cur, model.mm_fc_b);
@@ -1668,7 +1673,7 @@ private:
         }
 
         // TODO @ngxson : find a way to move this outside
-        if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
+        if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
             ggml_tensor * cur = inpL;
             cur = ggml_transpose(ctx0, cur);
             cur = ggml_cont(ctx0, cur);
@@ -1947,7 +1952,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
 
     ggml_cgraph * res;
 
-    switch (ctx->proj_type) {
+    switch (ctx->proj_type()) {
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
             {
@@ -1991,13 +1996,15 @@ struct clip_model_loader {
     ggml_context_ptr ctx_meta;
     gguf_context_ptr ctx_gguf;
 
-    clip_ctx & ctx_clip;
     std::string fname;
 
     size_t model_size = 0; // in bytes
 
-    // TODO @ngxson : we should not pass clip_ctx here, it should be clip_vision_model
-    clip_model_loader(const char * fname, clip_ctx & ctx_clip) : ctx_clip(ctx_clip), fname(fname) {
+    bool has_vision = false;
+    bool has_audio  = false;
+
+    // TODO @ngxson : we should not pass clip_ctx here, it should be clip_model
+    clip_model_loader(const char * fname) : fname(fname) {
         struct ggml_context * meta = nullptr;
 
         struct gguf_init_params params = {
@@ -2029,6 +2036,19 @@ struct clip_model_loader {
             LOG_INF("\n");
         }
 
+        // modalities
+        {
+            get_bool(KEY_HAS_VISION_ENC, has_vision, false);
+            get_bool(KEY_HAS_AUDIO_ENC,  has_audio,  false);
+
+            if (has_vision) {
+                LOG_INF("%s: has vision encoder\n", __func__);
+            }
+            if (has_audio) {
+                LOG_INF("%s: has audio encoder\n", __func__);
+            }
+        }
+
         // tensors
         {
             for (int i = 0; i < n_tensors; ++i) {
@@ -2044,28 +2064,44 @@ struct clip_model_loader {
         }
     }
 
-    void load_hparams() {
-        auto & hparams = ctx_clip.vision_model.hparams;
+    void load_hparams(clip_model & model, clip_modality modality) {
+        auto & hparams = model.hparams;
         std::string log_ffn_op; // for logging
 
+        // sanity check
+        if (modality == CLIP_MODALITY_VISION) {
+            GGML_ASSERT(has_vision);
+        } else if (modality == CLIP_MODALITY_AUDIO) {
+            GGML_ASSERT(has_audio);
+        }
+        model.modality = modality;
+
+
         // projector type
         std::string proj_type;
         {
             get_string(KEY_PROJ_TYPE, proj_type, false);
             if (!proj_type.empty()) {
-                ctx_clip.proj_type = clip_projector_type_from_string(proj_type);
+                model.proj_type = clip_projector_type_from_string(proj_type);
             }
-            if (ctx_clip.proj_type == PROJECTOR_TYPE_UNKNOWN) {
+            if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) {
                 throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str()));
             }
+
+            // correct arch for multimodal models
+            if (model.proj_type == PROJECTOR_TYPE_QWEN25O) {
+                model.proj_type = modality == CLIP_MODALITY_VISION
+                                    ? PROJECTOR_TYPE_QWEN25VL
+                                    : PROJECTOR_TYPE_QWEN2A;
+            }
         }
 
+        const bool is_vision = model.modality == CLIP_MODALITY_VISION;
+        const bool is_audio  = model.modality == CLIP_MODALITY_AUDIO;
+
         // other hparams
         {
-            get_bool(KEY_HAS_AUDIO_ENC,  hparams.has_audio, false);
-            get_bool(KEY_HAS_VISION_ENC, hparams.has_vision, false);
-
-            const char * prefix = hparams.has_vision ? "vision" : "audio";
+            const char * prefix = is_vision ? "vision" : "audio";
             get_u32(string_format(KEY_N_EMBD,         prefix), hparams.n_embd);
             get_u32(string_format(KEY_N_HEAD,         prefix), hparams.n_head);
             get_u32(string_format(KEY_N_FF,           prefix), hparams.n_ff);
@@ -2073,27 +2109,27 @@ struct clip_model_loader {
             get_u32(string_format(KEY_PROJ_DIM,       prefix), hparams.projection_dim);
             get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps);
 
-            if (hparams.has_vision) {
+            if (is_vision) {
                 get_u32(KEY_IMAGE_SIZE, hparams.image_size);
                 get_u32(KEY_PATCH_SIZE, hparams.patch_size);
                 get_u32(KEY_IMAGE_CROP_RESOLUTION,    hparams.image_crop_resolution, false);
                 get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
-                get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
+                get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy
 
-            } else if (hparams.has_audio) {
+            } else if (is_audio) {
                 get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins);
 
             } else {
-                throw std::runtime_error(string_format("%s: neither vision nor audio encoder is present\n", __func__));
+                GGML_ASSERT(false && "unknown modality");
             }
 
             // default warmup value
             hparams.warmup_image_size = hparams.image_size;
 
-            ctx_clip.has_llava_projector = ctx_clip.proj_type == PROJECTOR_TYPE_MLP
-                                        || ctx_clip.proj_type == PROJECTOR_TYPE_MLP_NORM
-                                        || ctx_clip.proj_type == PROJECTOR_TYPE_LDP
-                                        || ctx_clip.proj_type == PROJECTOR_TYPE_LDPV2;
+            hparams.has_llava_projector = model.proj_type == PROJECTOR_TYPE_MLP
+                                       || model.proj_type == PROJECTOR_TYPE_MLP_NORM
+                                       || model.proj_type == PROJECTOR_TYPE_LDP
+                                       || model.proj_type == PROJECTOR_TYPE_LDPV2;
 
             {
                 bool use_gelu = false;
@@ -2123,7 +2159,7 @@ struct clip_model_loader {
                 }
             }
 
-            if (hparams.has_vision) {
+            if (is_vision) {
                 int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
                 int idx_std  = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
                 GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
@@ -2131,8 +2167,8 @@ struct clip_model_loader {
                 const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean);
                 const float * std_data  = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std);
                 for (int i = 0; i < 3; ++i) {
-                    ctx_clip.image_mean[i] = mean_data[i];
-                    ctx_clip.image_std[i]  = std_data[i];
+                    hparams.image_mean[i] = mean_data[i];
+                    hparams.image_std[i]  = std_data[i];
                 }
             }
 
@@ -2149,11 +2185,11 @@ struct clip_model_loader {
             }
 
             // model-specific params
-            switch (ctx_clip.proj_type) {
+            switch (model.proj_type) {
                 case PROJECTOR_TYPE_MINICPMV:
                     {
-                        if (ctx_clip.minicpmv_version == 0) {
-                            ctx_clip.minicpmv_version = 2; // default to 2 if not set
+                        if (hparams.minicpmv_version == 0) {
+                            hparams.minicpmv_version = 2; // default to 2 if not set
                         }
                     } break;
                 case PROJECTOR_TYPE_IDEFICS3:
@@ -2212,7 +2248,7 @@ struct clip_model_loader {
                 case PROJECTOR_TYPE_ULTRAVOX:
                 case PROJECTOR_TYPE_QWEN2A:
                     {
-                        bool require_stack = ctx_clip.proj_type == PROJECTOR_TYPE_ULTRAVOX;
+                        bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX;
                         get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
                         if (hparams.n_mel_bins != 128) {
                             throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
@@ -2225,23 +2261,22 @@ struct clip_model_loader {
             }
 
             LOG_INF("%s: projector:          %s\n", __func__, proj_type.c_str());
-            LOG_INF("%s: has_vision_encoder: %d\n", __func__, hparams.has_vision);
-            LOG_INF("%s: has_audio_encoder:  %d\n", __func__, hparams.has_audio);
             LOG_INF("%s: n_embd:             %d\n", __func__, hparams.n_embd);
             LOG_INF("%s: n_head:             %d\n", __func__, hparams.n_head);
             LOG_INF("%s: n_ff:               %d\n", __func__, hparams.n_ff);
             LOG_INF("%s: n_layer:            %d\n", __func__, hparams.n_layer);
             LOG_INF("%s: ffn_op:             %s\n", __func__, log_ffn_op.c_str());
             LOG_INF("%s: projection_dim:     %d\n", __func__, hparams.projection_dim);
-            LOG_INF("\n");
-            if (hparams.has_vision) {
+            if (is_vision) {
+                LOG_INF("\n--- vision hparams ---\n");
                 LOG_INF("%s: image_size:         %d\n", __func__, hparams.image_size);
                 LOG_INF("%s: patch_size:         %d\n", __func__, hparams.patch_size);
-                LOG_INF("%s: has_llava_proj:     %d\n", __func__, ctx_clip.has_llava_projector);
-                LOG_INF("%s: minicpmv_version:   %d\n", __func__, ctx_clip.minicpmv_version);
+                LOG_INF("%s: has_llava_proj:     %d\n", __func__, hparams.has_llava_projector);
+                LOG_INF("%s: minicpmv_version:   %d\n", __func__, hparams.minicpmv_version);
                 LOG_INF("%s: proj_scale_factor:  %d\n", __func__, hparams.proj_scale_factor);
                 LOG_INF("%s: n_wa_pattern:       %d\n", __func__, hparams.n_wa_pattern);
-            } else if (hparams.has_audio) {
+            } else if (is_audio) {
+                LOG_INF("\n--- audio hparams ---\n");
                 LOG_INF("%s: n_mel_bins:         %d\n", __func__, hparams.n_mel_bins);
                 LOG_INF("%s: proj_stack_factor:  %d\n", __func__, hparams.proj_stack_factor);
             }
@@ -2251,13 +2286,14 @@ struct clip_model_loader {
         }
     }
 
-    void load_tensors() {
-        auto & hparams = ctx_clip.vision_model.hparams;
+    void load_tensors(clip_ctx & ctx_clip) {
+        auto & model = ctx_clip.model;
+        auto & hparams = model.hparams;
         std::map<std::string, size_t> tensor_offset;
         std::vector<ggml_tensor *> tensors_to_load;
 
         // TODO @ngxson : support both audio and video in the future
-        const char * prefix = hparams.has_audio ? "a" : "v";
+        const char * prefix = model.modality == CLIP_MODALITY_AUDIO ? "a" : "v";
 
         // get offsets
         for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) {
@@ -2292,26 +2328,24 @@ struct clip_model_loader {
             return cur;
         };
 
-        auto & vision_model = ctx_clip.vision_model; // TODO: rename this to just "model"
-
-        vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
+        model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
 
-        vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
-        vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"),   false);
+        model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
+        model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"),   false);
 
-        vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
-        vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"),   false);
+        model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
+        model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"),   false);
 
-        vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
-        vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD,   false);
-        vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
+        model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
+        model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD,   false);
+        model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
 
-        vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
+        model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
 
         // layers
-        vision_model.layers.resize(hparams.n_layer);
+        model.layers.resize(hparams.n_layer);
         for (int il = 0; il < hparams.n_layer; ++il) {
-            auto & layer = vision_model.layers[il];
+            auto & layer = model.layers[il];
             layer.k_w    = get_tensor(string_format(TN_ATTN_K,      prefix, il, "weight"));
             layer.q_w    = get_tensor(string_format(TN_ATTN_Q,      prefix, il, "weight"));
             layer.v_w    = get_tensor(string_format(TN_ATTN_V,      prefix, il, "weight"));
@@ -2352,166 +2386,166 @@ struct clip_model_loader {
             }
         }
 
-        switch (ctx_clip.proj_type) {
+        switch (model.proj_type) {
             case PROJECTOR_TYPE_MLP:
             case PROJECTOR_TYPE_MLP_NORM:
                 {
                     // LLaVA projection
-                    vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false);
-                    vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false);
+                    model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false);
+                    model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false);
                     // Yi-type llava
-                    vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false);
-                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false);
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
                     // missing in Yi-type llava
-                    vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false);
-                    vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false);
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
                     // Yi-type llava
-                    vision_model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false);
-                    vision_model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false);
-                    vision_model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false);
-                    vision_model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false);
-                    if (vision_model.mm_3_w) {
+                    model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false);
+                    model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false);
+                    model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false);
+                    model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false);
+                    if (model.mm_3_w) {
                         // TODO: this is a hack to support Yi-type llava
-                        ctx_clip.proj_type = PROJECTOR_TYPE_MLP_NORM;
+                        model.proj_type = PROJECTOR_TYPE_MLP_NORM;
                     }
-                    vision_model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false);
+                    model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false);
                 } break;
             case PROJECTOR_TYPE_LDP:
                 {
                     // MobileVLM projection
-                    vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
-                    vision_model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
-                    vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
-                    vision_model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
-                    vision_model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight"));
-                    vision_model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight"));
-                    vision_model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias"));
-                    vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight"));
-                    vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias"));
-                    vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight"));
-                    vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias"));
-                    vision_model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight"));
-                    vision_model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight"));
-                    vision_model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias"));
-                    vision_model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight"));
-                    vision_model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight"));
-                    vision_model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias"));
-                    vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight"));
-                    vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias"));
-                    vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight"));
-                    vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias"));
-                    vision_model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
-                    vision_model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
-                    vision_model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
+                    model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+                    model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+                    model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+                    model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight"));
+                    model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight"));
+                    model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias"));
+                    model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight"));
+                    model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias"));
+                    model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight"));
+                    model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias"));
+                    model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight"));
+                    model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight"));
+                    model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias"));
+                    model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight"));
+                    model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight"));
+                    model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias"));
+                    model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight"));
+                    model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias"));
+                    model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight"));
+                    model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias"));
+                    model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
+                    model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
+                    model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
                 } break;
             case PROJECTOR_TYPE_LDPV2:
                 {
                     // MobilVLM_V2 projection
-                    vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
-                    vision_model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
-                    vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
-                    vision_model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias"));
-                    vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight"));
-                    vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias"));
+                    model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+                    model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+                    model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+                    model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias"));
+                    model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight"));
+                    model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias"));
                 } break;
             case PROJECTOR_TYPE_MINICPMV:
                 {
-                    // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
-                    vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
-                    vision_model.mm_model_query = get_tensor(TN_MINICPMV_QUERY);
-                    vision_model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ);
-                    vision_model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ);
-                    vision_model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight"));
-                    vision_model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight"));
-                    vision_model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight"));
-                    vision_model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias"));
-                    vision_model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias"));
-                    vision_model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias"));
-                    vision_model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight"));
-                    vision_model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias"));
-                    vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight"));
-                    vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias"));
-                    vision_model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight"));
-                    vision_model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias"));
-                    vision_model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight"));
-                    vision_model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias"));
+                    // model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
+                    model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
+                    model.mm_model_query = get_tensor(TN_MINICPMV_QUERY);
+                    model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ);
+                    model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ);
+                    model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight"));
+                    model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight"));
+                    model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight"));
+                    model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias"));
+                    model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias"));
+                    model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias"));
+                    model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight"));
+                    model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias"));
+                    model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight"));
+                    model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias"));
+                    model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight"));
+                    model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias"));
+                    model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight"));
+                    model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias"));
                 } break;
             case PROJECTOR_TYPE_GLM_EDGE:
                 {
-                    vision_model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight"));
-                    vision_model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias"));
-                    vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR, "weight"));
-                    vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "weight"));
-                    vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "bias"));
-                    vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight"));
-                    vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight"));
-                    vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight"));
-                    vision_model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
-                    vision_model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
+                    model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight"));
+                    model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias"));
+                    model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR, "weight"));
+                    model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "weight"));
+                    model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "bias"));
+                    model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight"));
+                    model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight"));
+                    model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight"));
+                    model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
+                    model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
                 } break;
             case PROJECTOR_TYPE_QWEN2VL:
             case PROJECTOR_TYPE_QWEN25VL:
                 {
-                    vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
-                    vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
-                    vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
-                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                    model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
+                    model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
                 } break;
             case PROJECTOR_TYPE_GEMMA3:
                 {
-                    vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
-                    vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
+                    model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
+                    model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
                 } break;
             case PROJECTOR_TYPE_IDEFICS3:
                 {
-                    vision_model.projection = get_tensor(TN_MM_PROJECTOR);
+                    model.projection = get_tensor(TN_MM_PROJECTOR);
                 } break;
             case PROJECTOR_TYPE_PIXTRAL:
                 {
-                    vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
-                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
-                    vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
-                    vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
                     // [IMG_BREAK] token embedding
-                    vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+                    model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
                     // for mistral small 3.1
-                    vision_model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM,     false);
-                    vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
+                    model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM,     false);
+                    model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
             case PROJECTOR_TYPE_ULTRAVOX:
                 {
-                    vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
-                    vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
-                    vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
-                    vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
-                    vision_model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
-                    vision_model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
-                    vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
-                    vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
+                    model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
+                    model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
+                    model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
+                    model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
                 } break;
             case PROJECTOR_TYPE_QWEN2A:
                 {
-                    vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
-                    vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
-                    vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
-                    vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
-                    vision_model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
-                    vision_model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
+                    model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
+                    model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
                 } break;
             case PROJECTOR_TYPE_INTERNVL:
                 {
-                    vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
-                    vision_model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
-                    vision_model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
-                    vision_model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
-                    vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
-                    vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+                    model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+                    model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+                    model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+                    model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
                 } break;
             case PROJECTOR_TYPE_LLAMA4:
                 {
-                    vision_model.mm_model_proj    = get_tensor(TN_MM_PROJECTOR);
-                    vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
-                    vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+                    model.mm_model_proj    = get_tensor(TN_MM_PROJECTOR);
+                    model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
                 } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
@@ -2554,21 +2588,20 @@ struct clip_model_loader {
         }
     }
 
-    void alloc_compute_meta() {
-        const auto & hparams = ctx_clip.vision_model.hparams;
+    void alloc_compute_meta(clip_ctx & ctx_clip) {
+        const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
 
         // create a fake batch
         clip_image_f32_batch batch;
         clip_image_f32_ptr img(clip_image_f32_init());
-        if (hparams.has_vision) {
+        if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
             img->nx = hparams.warmup_image_size;
             img->ny = hparams.warmup_image_size;
         } else {
-            img->nx = 1024; // TODO @ngxson : use a better default
+            img->nx = hparams.warmup_audio_size;
             img->ny = hparams.n_mel_bins;
         }
-        img->buf.resize(img->nx * img->ny * 3);
         batch.entries.push_back(std::move(img));
 
         ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
@@ -2646,23 +2679,40 @@ struct clip_model_loader {
     }
 };
 
-struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) {
+struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
     g_logger_state.verbosity_thold = ctx_params.verbosity;
-    clip_ctx * ctx_clip = nullptr;
+    clip_ctx * ctx_vision = nullptr;
+    clip_ctx * ctx_audio = nullptr;
 
     try {
-        ctx_clip = new clip_ctx(ctx_params);
-        clip_model_loader loader(fname, *ctx_clip);
-        loader.load_hparams();
-        loader.load_tensors();
-        loader.alloc_compute_meta();
+        clip_model_loader loader(fname);
+
+        if (loader.has_vision) {
+            ctx_vision = new clip_ctx(ctx_params);
+            loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
+            loader.load_tensors(*ctx_vision);
+            loader.alloc_compute_meta(*ctx_vision);
+        }
+
+        if (loader.has_audio) {
+            ctx_audio = new clip_ctx(ctx_params);
+            loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
+            loader.load_tensors(*ctx_audio);
+            loader.alloc_compute_meta(*ctx_audio);
+        }
+
     } catch (const std::exception & e) {
         LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
-        delete ctx_clip;
-        return nullptr;
+        if (ctx_vision) {
+            delete ctx_vision;
+        }
+        if (ctx_audio) {
+            delete ctx_audio;
+        }
+        return {nullptr, nullptr};
     }
 
-    return ctx_clip;
+    return {ctx_vision, ctx_audio};
 }
 
 struct clip_image_size * clip_image_size_init() {
@@ -3023,12 +3073,12 @@ struct llava_uhd {
         const float ratio = (float)original_width * original_height / (slice_size * slice_size);
         const int multiple = fmin(ceil(ratio), max_slice_nums);
         const bool has_slices = (multiple > 1);
-        const bool has_pinpoints = !ctx->vision_model.hparams.image_grid_pinpoints.empty();
+        const bool has_pinpoints = !ctx->model.hparams.image_grid_pinpoints.empty();
 
         if (has_pinpoints) {
             // has pinpoints, use them to calculate the grid size (e.g. llava-1.6)
             auto refine_size = llava_uhd::select_best_resolution(
-                ctx->vision_model.hparams.image_grid_pinpoints,
+                ctx->model.hparams.image_grid_pinpoints,
                 original_size);
             res.overview_size   = clip_image_size{slice_size, slice_size};
             res.refined_size    = refine_size;
@@ -3250,7 +3300,7 @@ private:
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
     clip_image_size original_size{img->nx, img->ny};
     bool pad_to_square = true;
-    auto & params = ctx->vision_model.hparams;
+    auto & params = ctx->model.hparams;
     // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
     if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) {
         pad_to_square = false;
@@ -3263,7 +3313,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         for (size_t i = 0; i < imgs.size(); ++i) {
             // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
             clip_image_f32_ptr res(clip_image_f32_init());
-            normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
+            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
             res_imgs->entries.push_back(std::move(res));
         }
 
@@ -3271,7 +3321,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->grid_y = inst.grid_size.height;
         return true;
 
-    } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
         clip_image_u8 resized;
         auto patch_size = params.patch_size * 2;
         auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);
@@ -3279,42 +3329,42 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
 
         clip_image_f32_ptr img_f32(clip_image_f32_init());
         // clip_image_f32_ptr res(clip_image_f32_init());
-        normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std);
+        normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std);
         // res_imgs->data[0] = *res;
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
     }
-    else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
-            || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
-            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3
-            || ctx->proj_type == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
+    else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE
+            || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3
+            || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3
+            || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
     ) {
         clip_image_u8 resized_image;
         int sz = params.image_size;
         image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
         clip_image_f32_ptr img_f32(clip_image_f32_init());
         //clip_image_save_to_bmp(resized_image, "resized.bmp");
-        normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
+        normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
 
-    } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
         clip_image_u8 resized_image;
         auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
         image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
         clip_image_f32_ptr img_f32(clip_image_f32_init());
-        normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
+        normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
 
-    } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_LLAMA4) {
         GGML_ASSERT(!params.image_grid_pinpoints.empty());
         auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
         std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
 
         for (size_t i = 0; i < imgs.size(); ++i) {
             clip_image_f32_ptr res(clip_image_f32_init());
-            normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
+            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
             res_imgs->entries.push_back(std::move(res));
         }
 
@@ -3344,7 +3394,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color);
 
         clip_image_f32_ptr res(clip_image_f32_init());
-        normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std);
+        normalize_image_u8_to_f32(*temp, *res, params.image_mean, params.image_std);
         res_imgs->entries.push_back(std::move(res));
         return true;
 
@@ -3356,7 +3406,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         for (size_t i = 0; i < imgs.size(); ++i) {
             // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
             clip_image_f32_ptr res(clip_image_f32_init());
-            normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
+            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
             res_imgs->entries.push_back(std::move(res));
         }
 
@@ -3368,7 +3418,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
 }
 
 ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
-    return ctx->vision_model.image_newline;
+    return ctx->model.image_newline;
 }
 
 void clip_free(clip_ctx * ctx) {
@@ -3380,8 +3430,8 @@ void clip_free(clip_ctx * ctx) {
 
 // deprecated
 size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
-    const int32_t nx = ctx->vision_model.hparams.image_size;
-    const int32_t ny = ctx->vision_model.hparams.image_size;
+    const int32_t nx = ctx->model.hparams.image_size;
+    const int32_t ny = ctx->model.hparams.image_size;
     return clip_embd_nbytes_by_img(ctx, nx, ny);
 }
 
@@ -3393,105 +3443,135 @@ size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h
 }
 
 int32_t clip_get_image_size(const struct clip_ctx * ctx) {
-    return ctx->vision_model.hparams.image_size;
+    return ctx->model.hparams.image_size;
 }
 
 int32_t clip_get_patch_size(const struct clip_ctx * ctx) {
-    return ctx->vision_model.hparams.patch_size;
+    return ctx->model.hparams.patch_size;
 }
 
 int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
-    return ctx->vision_model.hparams.n_embd;
+    return ctx->model.hparams.n_embd;
 }
 
 const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
-    return ctx->vision_model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat";
+    return ctx->model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat";
 }
 
 const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
-    if (ctx->vision_model.hparams.image_grid_pinpoints.size()) {
-        return &ctx->vision_model.hparams.image_grid_pinpoints.front();
+    if (ctx->model.hparams.image_grid_pinpoints.size()) {
+        return &ctx->model.hparams.image_grid_pinpoints.front();
     }
     return nullptr;
 }
 
 size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
-    return ctx->vision_model.hparams.image_grid_pinpoints.size();
+    return ctx->model.hparams.image_grid_pinpoints.size();
 }
 
 int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
-    const auto & params = ctx->vision_model.hparams;
+    const auto & params = ctx->model.hparams;
     const int n_total = clip_n_output_tokens(ctx, img);
-    if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
         return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
     }
     return n_total;
 }
 
 int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
-    const auto & params = ctx->vision_model.hparams;
-    if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+    const auto & params = ctx->model.hparams;
+    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
         return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
     }
     return 1;
 }
 
 int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
-    const auto & params = ctx->vision_model.hparams;
+    const auto & params = ctx->model.hparams;
 
-    int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
-    int scale_factor = ctx->vision_model.hparams.proj_scale_factor;
+    // only for models using fixed size square images
+    int n_patches_sq = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
 
-    if (ctx->proj_type == PROJECTOR_TYPE_LDP
-            || ctx->proj_type == PROJECTOR_TYPE_LDPV2
-            || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
-        n_patches /= 4;
-        if (ctx->vision_model.mm_glm_tok_boi) {
-            n_patches += 2; // for BOI and EOI token embeddings
-        }
-    } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
-        if (ctx->minicpmv_version == 2) {
-            n_patches = 96;
-        }
-        else if (ctx->minicpmv_version == 3) {
-            n_patches = 64;
-        }
-        else if (ctx->minicpmv_version == 4) {
-            n_patches = 64;
-        }
-        else {
-            GGML_ABORT("Unknown minicpmv version");
-        }
-    } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
-        int patch_size = params.patch_size * 2;
-        int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
-        int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
-        n_patches = x_patch * y_patch;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
-        int n_per_side = params.image_size / params.patch_size;
-        int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
-        n_patches = n_per_side_2d_pool * n_per_side_2d_pool;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_INTERNVL) {
-        // both W and H are divided by proj_scale_factor
-        n_patches /= (params.proj_scale_factor * params.proj_scale_factor);
-    } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
-        int n_merge = params.spatial_merge_size;
-        int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
-        int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
-        n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
-    } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
-        n_patches /= (scale_factor * scale_factor);
-    } else if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
-        const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor;
-        const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
-        n_patches = n_len / proj_stack_factor / 2;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
-        // divide by 2 because of whisper
-        // another divide by 2 because of nn.AvgPool1d(2, stride=2)
-        n_patches = img->nx / 4;
-    }
-
-    return n_patches;
+    projector_type proj = ctx->proj_type();
+
+    switch (proj) {
+        case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_MLP_NORM:
+            {
+                // do nothing
+            } break;
+        case PROJECTOR_TYPE_LDP:
+        case PROJECTOR_TYPE_LDPV2:
+        case PROJECTOR_TYPE_GLM_EDGE:
+            {
+                n_patches_sq /= 4;
+                if (ctx->model.mm_glm_tok_boi) {
+                    n_patches_sq += 2; // for BOI and EOI token embeddings
+                }
+            } break;
+        case PROJECTOR_TYPE_MINICPMV:
+            {
+                if (params.minicpmv_version == 2) {
+                    n_patches_sq = 96;
+                } else if (params.minicpmv_version == 3) {
+                    n_patches_sq = 64;
+                } else if (params.minicpmv_version == 4) {
+                    n_patches_sq = 64;
+                } else {
+                    GGML_ABORT("Unknown minicpmv version");
+                }
+            } break;
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+            {
+                // dynamic size
+                int patch_size = params.patch_size * 2;
+                int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
+                int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
+                n_patches_sq = x_patch * y_patch;
+            } break;
+        case PROJECTOR_TYPE_GEMMA3:
+            {
+                int n_per_side = params.image_size / params.patch_size;
+                int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
+                n_patches_sq = n_per_side_2d_pool * n_per_side_2d_pool;
+            } break;
+        case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_INTERNVL:
+            {
+                // both W and H are divided by proj_scale_factor
+                n_patches_sq /= (params.proj_scale_factor * params.proj_scale_factor);
+            } break;
+        case PROJECTOR_TYPE_PIXTRAL:
+            {
+                // dynamic size
+                int n_merge = params.spatial_merge_size;
+                int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
+                int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
+                n_patches_sq = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+            } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                int scale_factor = ctx->model.hparams.proj_scale_factor;
+                n_patches_sq /= (scale_factor * scale_factor);
+            } break;
+        case PROJECTOR_TYPE_ULTRAVOX:
+            {
+                const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
+                const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
+                n_patches_sq = n_len / proj_stack_factor / 2;
+            } break;
+        case PROJECTOR_TYPE_QWEN2A:
+            {
+                // divide by 2 because of whisper
+                // another divide by 2 because of nn.AvgPool1d(2, stride=2)
+                n_patches_sq = img->nx / 4;
+            } break;
+        default:
+            GGML_ABORT("unsupported projector type");
+    }
+
+    return n_patches_sq;
 }
 
 static std::vector<std::vector<std::vector<float>>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector<std::vector<float>> & pos) {
@@ -3606,7 +3686,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
 
     // set inputs
-    const auto & model   = ctx->vision_model;
+    const auto & model   = ctx->model;
     const auto & hparams = model.hparams;
 
     const int image_size_width  = imgs.entries[0]->nx;
@@ -3696,7 +3776,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     }
 
     // set input per projector
-    switch (ctx->proj_type) {
+    switch (ctx->model.proj_type) {
         case PROJECTOR_TYPE_MINICPMV:
             {
                 // inspired from siglip:
@@ -3961,80 +4041,83 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 }
 
 int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
-    switch (ctx->proj_type) {
+    const auto & hparams = ctx->model.hparams;
+    switch (ctx->model.proj_type) {
         case PROJECTOR_TYPE_LDP:
-            return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
+            return ctx->model.mm_model_block_1_block_2_1_b->ne[0];
         case PROJECTOR_TYPE_LDPV2:
-            return ctx->vision_model.mm_model_peg_0_b->ne[0];
+            return ctx->model.mm_model_peg_0_b->ne[0];
         case PROJECTOR_TYPE_MLP:
         case PROJECTOR_TYPE_PIXTRAL:
-            return ctx->vision_model.mm_2_w->ne[1];
+            return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_MLP_NORM:
-            return ctx->vision_model.mm_3_b->ne[0];
+            return ctx->model.mm_3_b->ne[0];
         case PROJECTOR_TYPE_MINICPMV:
-            if (ctx->minicpmv_version == 2) {
+            if (hparams.minicpmv_version == 2) {
                 return 4096;
-            } else if (ctx->minicpmv_version == 3) {
+            } else if (hparams.minicpmv_version == 3) {
                 return 3584;
-            } else if (ctx->minicpmv_version == 4) {
+            } else if (hparams.minicpmv_version == 4) {
                 return 3584;
             }
             GGML_ABORT("Unknown minicpmv version");
         case PROJECTOR_TYPE_GLM_EDGE:
-            return ctx->vision_model.mm_model_mlp_3_w->ne[1];
+            return ctx->model.mm_model_mlp_3_w->ne[1];
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
-            return ctx->vision_model.mm_1_b->ne[0];
+            return ctx->model.mm_1_b->ne[0];
         case PROJECTOR_TYPE_GEMMA3:
-            return ctx->vision_model.mm_input_proj_w->ne[0];
+            return ctx->model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
-            return ctx->vision_model.projection->ne[1];
+            return ctx->model.projection->ne[1];
         case PROJECTOR_TYPE_ULTRAVOX:
-            return ctx->vision_model.mm_2_w->ne[1];
+            return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_INTERNVL:
-            return ctx->vision_model.mm_3_w->ne[1];
+            return ctx->model.mm_3_w->ne[1];
         case PROJECTOR_TYPE_LLAMA4:
-            return ctx->vision_model.mm_model_proj->ne[1];
+            return ctx->model.mm_model_proj->ne[1];
         case PROJECTOR_TYPE_QWEN2A:
-            return ctx->vision_model.mm_fc_w->ne[1];
+            return ctx->model.mm_fc_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }
 }
 
 int clip_is_minicpmv(const struct clip_ctx * ctx) {
-    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
-        return ctx->minicpmv_version;
+    if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV) {
+        return ctx->model.hparams.minicpmv_version;
     }
     return 0;
 }
 
 bool clip_is_glm(const struct clip_ctx * ctx) {
-    return ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE;
+    return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE;
 }
 
 bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
-    return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL;
+    return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL;
 }
 
 bool clip_is_llava(const struct clip_ctx * ctx) {
-    return ctx->has_llava_projector;
+    return ctx->model.hparams.has_llava_projector;
 }
 
 bool clip_is_gemma3(const struct clip_ctx * ctx) {
-    return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
+    return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3;
 }
 
 bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
-    return ctx->vision_model.hparams.has_vision;
+    return ctx->model.modality == CLIP_MODALITY_VISION;
 }
 
 bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
-    return ctx->vision_model.hparams.has_audio;
+    return ctx->model.modality == CLIP_MODALITY_AUDIO;
 }
 
 bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
-    return ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX || ctx->proj_type == PROJECTOR_TYPE_QWEN2A;
+    return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A;
 }
 
 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
@@ -4055,7 +4138,7 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
 //
 
 projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
-    return ctx->proj_type;
+    return ctx->proj_type();
 }
 
 void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) {
index 5abfcd1a3c418dcefaa68279f0c72741b556d2b6..cb2eb261fe2e8b5171ca026ca498cacd0d565e26 100644 (file)
@@ -17,12 +17,22 @@ struct clip_image_f32;
 struct clip_image_u8_batch;
 struct clip_image_f32_batch;
 
+enum clip_modality {
+    CLIP_MODALITY_VISION,
+    CLIP_MODALITY_AUDIO,
+};
+
 struct clip_context_params {
     bool use_gpu;
     enum ggml_log_level verbosity;
 };
 
-struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params);
+struct clip_init_result {
+    struct clip_ctx * ctx_v; // vision context
+    struct clip_ctx * ctx_a; // audio context
+};
+
+struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params);
 
 void clip_free(struct clip_ctx * ctx);
 
index 0f8bb0cdc42dc55f1c7ee5eda54850b889ec634a..a70f11ca9d7184e7f3929ede1812b78462ccc12b 100644 (file)
@@ -284,7 +284,9 @@ int main(int argc, char ** argv) {
     if (is_single_turn) {
         g_is_generating = true;
         if (params.prompt.find(mtmd_default_marker()) == std::string::npos) {
-            params.prompt += mtmd_default_marker();
+            for (size_t i = 0; i < params.image.size(); i++) {
+                params.prompt += mtmd_default_marker();
+            }
         }
         common_chat_msg msg;
         msg.role = "user";
index b79094c0a48b61a95fea8f633261300df213e22d..e6c926080cde335b7858de3c6e605a3e2ecd3182 100644 (file)
@@ -66,7 +66,8 @@ struct decode_embd_batch {
         }
     }
 
-    void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
+    // M-RoPE for image
+    void set_position_mrope_2d(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
         GGML_ASSERT(n_pos_per_embd == 4);
         seq_id_0[0] = seq_id;
         for (int y = 0; y < ny; y++) {
@@ -85,6 +86,23 @@ struct decode_embd_batch {
         }
     }
 
+    // M-RoPE for audio
+    void set_position_mrope_1d(llama_pos pos_0, llama_seq_id seq_id) {
+        GGML_ASSERT(n_pos_per_embd == 4);
+        seq_id_0[0] = seq_id;
+        for (int i = 0; i < batch.n_tokens; i++) {
+            pos[i                     ] = pos_0 + i;
+            pos[i + batch.n_tokens    ] = pos_0 + i;
+            pos[i + batch.n_tokens * 2] = pos_0 + i;
+            pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
+        }
+        for (int i = 0; i < batch.n_tokens; i++) {
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+
     llama_batch get_view(int offset, int n_tokens) {
         llama_pos * pos_ptr;
         pos_view.clear();
@@ -146,18 +164,20 @@ int32_t mtmd_helper_decode_image_chunk(
     decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
 
     if (mtmd_decode_use_mrope(ctx)) {
-        const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
-        if (chunk_type != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-            LOG_ERR("failed to decode chunk: M-RoPE only accepts image chunk\n");
-            return -1;
-        }
-        if (!image_tokens) {
-            LOG_ERR("failed to decode chunk: image tokens are null\n");
-            return -1;
+        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+            if (!image_tokens) {
+                LOG_ERR("failed to decode chunk: image tokens are null\n");
+                return -1;
+            }
+            const int nx = mtmd_image_tokens_get_nx(image_tokens);
+            const int ny = mtmd_image_tokens_get_ny(image_tokens);
+            batch_embd.set_position_mrope_2d(n_past, nx, ny, seq_id);
+        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
+            batch_embd.set_position_mrope_1d(n_past, seq_id);
+        } else {
+            GGML_ABORT("invalid chunk type for M-RoPE");
         }
-        const int nx = mtmd_image_tokens_get_nx(image_tokens);
-        const int ny = mtmd_image_tokens_get_ny(image_tokens);
-        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
     } else {
         batch_embd.set_position_normal(n_past, seq_id);
     }
index c3be91265f331619304fac53f066e697f518c080..52bf71e2c9dc090345ce7bf2424e4a3a38b62685 100644 (file)
@@ -95,15 +95,21 @@ mtmd_context_params mtmd_context_params_default() {
 }
 
 struct mtmd_context {
-    struct clip_ctx * ctx_clip;
+    struct clip_ctx * ctx_v; // vision
+    struct clip_ctx * ctx_a; // audio
     const struct llama_model * text_model;
     std::vector<float> image_embd_v; // image embedding vector
 
     bool print_timings;
     int n_threads;
     std::string media_marker;
-    bool has_vision;
-    bool has_audio;
+    const int n_embd_text;
+
+    // these are not token, but strings used to mark the beginning and end of image/audio embeddings
+    std::string img_beg;
+    std::string img_end;
+    std::string aud_beg;
+    std::string aud_end;
 
     // for llava-uhd style models, we need special tokens in-between slices
     // minicpmv calls them "slices", llama 4 calls them "tiles"
@@ -132,33 +138,61 @@ struct mtmd_context {
         text_model   (text_model),
         print_timings(ctx_params.print_timings),
         n_threads    (ctx_params.n_threads),
-        media_marker (ctx_params.media_marker)
+        media_marker (ctx_params.media_marker),
+        n_embd_text  (llama_model_n_embd(text_model))
     {
         if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
             throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
         }
 
+        if (media_marker.empty()) {
+            throw std::runtime_error("media_marker must not be empty");
+        }
+
         clip_context_params ctx_clip_params;
         ctx_clip_params.use_gpu   = ctx_params.use_gpu;
         ctx_clip_params.verbosity = ctx_params.verbosity;
-        ctx_clip = clip_init(mmproj_fname, ctx_clip_params);
-        if (!ctx_clip) {
+        auto res = clip_init(mmproj_fname, ctx_clip_params);
+        ctx_v = res.ctx_v;
+        ctx_a = res.ctx_a;
+        if (!ctx_v && !ctx_a) {
             throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
         }
 
-        if (llama_model_n_embd(text_model) != clip_n_mmproj_embd(ctx_clip)) {
+        // if both vision and audio mmproj are present, we need to validate their n_embd
+        if (ctx_v && ctx_a) {
+            int n_embd_v = clip_n_mmproj_embd(ctx_v);
+            int n_embd_a = clip_n_mmproj_embd(ctx_a);
+            if (n_embd_v != n_embd_a) {
+                throw std::runtime_error(string_format(
+                    "mismatch between vision and audio mmproj (n_embd_v = %d, n_embd_a = %d)\n",
+                    n_embd_v, n_embd_a));
+            }
+        }
+
+        // since we already validate n_embd of vision and audio mmproj,
+        // we can safely assume that they are the same
+        int n_embd_clip = clip_n_mmproj_embd(ctx_v ? ctx_v : ctx_a);
+        if (n_embd_text != n_embd_clip) {
             throw std::runtime_error(string_format(
                 "mismatch between text model (n_embd = %d) and mmproj (n_embd = %d)\n"
                 "hint: you may be using wrong mmproj\n",
-                llama_model_n_embd(text_model), clip_n_mmproj_embd(ctx_clip)));
+                n_embd_text, n_embd_clip));
+        }
+        if (ctx_v) {
+            init_vision();
         }
+        if (ctx_a) {
+            init_audio();
+        }
+    }
 
-        has_vision = clip_has_vision_encoder(ctx_clip);
-        has_audio  = clip_has_audio_encoder(ctx_clip);
-        use_mrope  = clip_is_qwen2vl(ctx_clip);
+    void init_vision() {
+        GGML_ASSERT(ctx_v != nullptr);
+        use_mrope = clip_is_qwen2vl(ctx_v);
 
-        projector_type proj = clip_get_projector_type(ctx_clip);
-        int minicpmv_version = clip_is_minicpmv(ctx_clip);
+        projector_type proj = clip_get_projector_type(ctx_v);
+        int minicpmv_version = clip_is_minicpmv(ctx_v);
         if (minicpmv_version == 2) {
             // minicpmv 2.5 format:
             // <image> (overview) </image><slice><image> (slice) </image><image> (slice) </image>\n ... </slice>
@@ -203,24 +237,82 @@ struct mtmd_context {
             ov_img_first      = false; // overview image is last
         }
 
-        if (clip_has_whisper_encoder(ctx_clip)) {
+        // set boi/eoi
+        if (proj == PROJECTOR_TYPE_GEMMA3) {
+            // <start_of_image> ... (image embeddings) ... <end_of_image>
+            img_beg = "<start_of_image>";
+            img_end = "<end_of_image>";
+
+        } else if (proj == PROJECTOR_TYPE_IDEFICS3) {
+            // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
+            img_beg = "<fake_token_around_image><global-img>";
+            img_end = "<fake_token_around_image>";
+
+        } else if (proj == PROJECTOR_TYPE_PIXTRAL) {
+            // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
+            img_end = "[IMG_END]";
+
+        } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL) {
+            // <|vision_start|> ... (image embeddings) ... <|vision_end|>
+            img_beg = "<|vision_start|>";
+            img_end = "<|vision_end|>";
+
+        } else if (proj == PROJECTOR_TYPE_LLAMA4) {
+            // (more details in mtmd_context constructor)
+            img_beg = "<|image_start|>";
+            img_end = "<|image_end|>";
+            LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
+                    "    https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
+
+        } else if (proj == PROJECTOR_TYPE_INTERNVL) {
+            // <img> ... (image embeddings) ... </img>
+            img_beg = "<img>";
+            img_end = "</img>";
+
+        }
+    }
+
+    void init_audio() {
+        GGML_ASSERT(ctx_a != nullptr);
+        projector_type proj = clip_get_projector_type(ctx_a);
+
+        if (clip_has_whisper_encoder(ctx_a)) {
             // TODO @ngxson : check if model n_mel is 128 or 80
             w_filters = whisper_precalc_filters::get_128_bins();
         }
 
-        // warning messages
-        if (proj == PROJECTOR_TYPE_LLAMA4) {
-            LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
-                    "    https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
+        LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
+                "    https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__);
+
+        if (proj == PROJECTOR_TYPE_QWEN2A) {
+            // <|audio_bos|> ... (embeddings) ... <|audio_eos|>
+            aud_beg = "<|audio_bos|>";
+            aud_end = "<|audio_eos|>";
+
         }
-        if (has_audio) {
-            LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
-                    "    https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__);
+    }
+
+    // get clip ctx based on chunk type
+    clip_ctx * get_clip_ctx(const mtmd_input_chunk * chunk) const {
+        if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            return ctx_v;
+        } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
+            return ctx_a;
         }
+        GGML_ABORT("unknown chunk type");
+    }
+
+    projector_type proj_type_v() const {
+        return ctx_v ? clip_get_projector_type(ctx_v) : PROJECTOR_TYPE_UNKNOWN;
+    }
+
+    projector_type proj_type_a() const {
+        return ctx_a ? clip_get_projector_type(ctx_a) : PROJECTOR_TYPE_UNKNOWN;
     }
 
     ~mtmd_context() {
-        clip_free(ctx_clip);
+        clip_free(ctx_a);
+        clip_free(ctx_v);
     }
 
 private:
@@ -267,167 +359,137 @@ void mtmd_free(mtmd_context * ctx) {
     }
 }
 
-// copied from common_tokenize
-static std::vector<llama_token> mtmd_tokenize_text_internal(
-    const struct llama_vocab * vocab,
-           const std::string & text,
-                        bool   add_special,
-                        bool   parse_special) {
-    // upper limit for the number of tokens
-    int n_tokens = text.length() + 2 * add_special;
-    std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-    return result;
-}
+struct mtmd_tokenizer {
+    mtmd_context * ctx;
+    std::vector<const mtmd_bitmap *> bitmaps;
 
-int32_t mtmd_tokenize(mtmd_context * ctx,
-            mtmd_input_chunks * output,
+    std::string input_text;
+    bool add_special;
+    bool parse_special;
+    const llama_vocab * vocab;
+
+    mtmd_input_chunks cur;
+
+    mtmd_tokenizer(mtmd_context * ctx,
             const mtmd_input_text * text,
             const mtmd_bitmap ** bitmaps,
-            size_t n_bitmaps) {
-    auto vocab = llama_model_get_vocab(ctx->text_model);
-
-    std::string prompt_modified(text->text);
-    std::string marker_modified(ctx->media_marker);
-    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
-
-    // for compatibility, we convert image marker to media marker
-    string_replace_all(prompt_modified, MTMD_DEFAULT_IMAGE_MARKER, ctx->media_marker);
-
-    // a bit hacky here, but works for now
-    // for some models, we need to add prefix and suffix to the image embeddings
-    if (clip_is_gemma3(ctx->ctx_clip)) {
-        // gemma 3
-        // <start_of_image> ... (image embeddings) ... <end_of_image>
-        marker_modified = "<start_of_image>" + ctx->media_marker + "<end_of_image>";
-        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
-    } else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
-        // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
-        marker_modified = "<fake_token_around_image><global-img>" + ctx->media_marker + "<fake_token_around_image>";
-        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
-    } else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
-        // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
-        marker_modified = ctx->media_marker + "[IMG_END]";
-        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
-    } else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
-        // <|vision_start|> ... (image embeddings) ... <|vision_end|>
-        marker_modified = "<|vision_start|>" + ctx->media_marker + "<|vision_end|>";
-        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
-    } else if (proj_type == PROJECTOR_TYPE_LLAMA4) {
-        // (more details in mtmd_context constructor)
-        marker_modified = "<|image_start|>" + ctx->media_marker + "<|image_end|>";
-        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
-    } else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
-        // <img> ... (image embeddings) ... </img>
-        marker_modified = "<img>" + ctx->media_marker + "</img>";
-        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
-    } else if (proj_type == PROJECTOR_TYPE_QWEN2A) {
-        // <|audio_bos|> ... (embeddings) ... <|audio_eos|>
-        marker_modified = "<|audio_bos|>" + ctx->media_marker + "<|audio_eos|>";
-        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
-
-    }
-
-    // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
-    // for glm-edge, BOI and EOI token's embeddings are not present in the text model
-
-    std::vector<std::string> parts = string_split_str(prompt_modified, ctx->media_marker);
-    output->entries.clear();
-    output->entries.reserve(parts.size());
-
-    size_t i_bm = 0;
-
-    // utility for adding raw tokens
-    auto add_text_chunk = [&output](std::vector<llama_token> && tokens) {
-        mtmd_input_chunk chunk{
-            MTMD_INPUT_CHUNK_TYPE_TEXT,
-            std::move(tokens),
-            nullptr, // image tokens
-            nullptr, // audio tokens
-        };
-        output->entries.emplace_back(std::move(chunk));
-    };
+            size_t n_bitmaps) : ctx(ctx), bitmaps(bitmaps, bitmaps + n_bitmaps) {
+        add_special   = text->add_special;
+        parse_special = text->parse_special;
+        input_text    = text->text;
+        vocab         = llama_model_get_vocab(ctx->text_model);
+
+        // for compatibility, we convert image marker to media marker
+        string_replace_all(input_text, MTMD_DEFAULT_IMAGE_MARKER, ctx->media_marker);
+    }
 
-    // utility for splitting batch of multiple images into chunks of batch having single images
-    auto split_batch_to_chunk = [&ctx](clip_image_f32_batch && batch_f32, const std::string & id) {
-        std::vector<mtmd_input_chunk> chunks;
+    int32_t tokenize(mtmd_input_chunks * output) {
+        cur.entries.clear();
+        std::vector<std::string> parts = split_text(input_text, ctx->media_marker);
+        size_t i_bm = 0; // index of the current bitmap
+        for (auto & part : parts) {
+            if (part == ctx->media_marker) {
+                // this is a marker, we should add the next bitmap
+                if (i_bm >= bitmaps.size()) {
+                    LOG_ERR("%s: error: number of bitmaps (%zu) does not match number of markers (%zu)\n",
+                            __func__, bitmaps.size(), parts.size() - 1);
+                    return 1;
+                }
+                const mtmd_bitmap * bitmap = bitmaps[i_bm++];
+                int32_t res = add_media(bitmap);
+                if (res != 0) {
+                    return res;
+                }
+            } else {
+                // this is a text part, we should add it as text
+                add_text(part, parse_special);
+            }
+        }
 
-        for (auto & entry : batch_f32.entries) {
-            mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
-            image_tokens->nx = clip_n_output_tokens(ctx->ctx_clip, entry.get());
-            image_tokens->ny = 1;
-            image_tokens->batch_f32.entries.push_back(std::move(entry));
-            image_tokens->id = id;
+        if (add_special && llama_vocab_get_add_bos(vocab)) {
+            // if first chunk is text, we add BOS token to first text chunk
+            // otherwise, create a new text chunk with BOS token
+            if (!cur.entries.empty() && cur.entries[0].type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+                // add BOS token to the beginning of first text chunk
+                cur.entries[0].tokens_text.insert(cur.entries[0].tokens_text.begin(), llama_vocab_bos(vocab));
+            } else {
+                // create a new text chunk with BOS token at the beginning
+                mtmd_input_chunk bos_chunk{
+                    MTMD_INPUT_CHUNK_TYPE_TEXT,
+                    {llama_vocab_bos(vocab)},
+                    nullptr, // image tokens
+                    nullptr, // audio tokens
+                };
+                cur.entries.insert(cur.entries.begin(), std::move(bos_chunk));
+            }
+        }
 
-            mtmd_input_chunk chunk{
-                MTMD_INPUT_CHUNK_TYPE_IMAGE,
-                {}, // text tokens
-                std::move(image_tokens),
-                nullptr, // audio tokens
-            };
-            chunks.emplace_back(std::move(chunk));
+        if (add_special && llama_vocab_get_add_eos(vocab)) {
+            // if last chunk is text, we add EOS token to it
+            add_text({llama_vocab_eos(vocab)});
         }
 
-        return chunks;
-    };
+        if (i_bm != bitmaps.size()) {
+            LOG_ERR("%s: error: number of bitmaps (%zu) does not match number of markers (%zu)\n",
+                    __func__, bitmaps.size(), parts.size() - 1);
+            return 1;
+        }
+
+        *output = std::move(cur);
+
+        return 0;
+    }
+
+    void add_text(const std::string & txt, bool parse_special) {
+        LOG_DBG("%s: %s\n", __func__, txt.c_str());
+        auto tokens = mtmd_tokenize_text_internal(vocab, txt, /* add_special */ false, parse_special);
+        add_text(tokens);
+    }
 
-    for (const auto & part : parts) {
-        // printf("tokenizing part: %s\n", part.c_str());
-        bool add_bos = &parts.front() == &part;
-        auto tokens = mtmd_tokenize_text_internal(vocab, part, text->add_special && add_bos, text->parse_special);
+    void add_text(const std::vector<llama_token> & tokens) {
         if (tokens.empty()) {
-            continue;
+            return;
         }
-        mtmd_input_chunk chunk{
-            MTMD_INPUT_CHUNK_TYPE_TEXT,
-            std::move(tokens),
-            nullptr, // image tokens
-            nullptr, // audio tokens
-        };
-        output->entries.emplace_back(std::move(chunk));
-
-        // only add image/audio tokens to middle of 2 parts
-        // therefore, we skip handling image/audio if this is the last part
-        if (&parts.back() == &part) {
-            continue;
+        // if last entry is also a text chunk, add tokens to it instead of creating new chunk
+        if (!cur.entries.empty() && cur.entries.back().type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            cur.entries.back().tokens_text.insert(
+                                            cur.entries.back().tokens_text.end(),
+                                            tokens.begin(),
+                                            tokens.end());
+        } else {
+            mtmd_input_chunk chunk{
+                MTMD_INPUT_CHUNK_TYPE_TEXT,
+                tokens,
+                nullptr, // image tokens
+                nullptr, // audio tokens
+            };
+            cur.entries.emplace_back(std::move(chunk));
         }
+    }
 
-        if (!bitmaps[i_bm]->is_audio) {
+    int32_t add_media(const mtmd_bitmap * bitmap) {
+        if (!bitmap->is_audio) {
             // handle image
 
-            if (i_bm >= n_bitmaps) {
-                LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
-                return 1;
-            }
-
-            if (!ctx->has_vision) {
+            if (!ctx->ctx_v) {
                 LOG_ERR("%s: error: model does not support vision input\n", __func__);
                 return 2;
             }
 
+            if (!ctx->img_beg.empty()) {
+                add_text(ctx->img_beg, true); // add image begin token
+            }
+
             // convert mtmd_bitmap to clip_image_u8
             clip_image_u8_ptr img_u8(clip_image_u8_init());
-            img_u8->nx = bitmaps[i_bm]->nx;
-            img_u8->ny = bitmaps[i_bm]->ny;
-            img_u8->buf.resize(bitmaps[i_bm]->data.size());
-            std::memcpy(img_u8->buf.data(), bitmaps[i_bm]->data.data(), img_u8->nx * img_u8->ny * 3);
+            img_u8->nx = bitmap->nx;
+            img_u8->ny = bitmap->ny;
+            img_u8->buf.resize(bitmap->data.size());
+            std::memcpy(img_u8->buf.data(), bitmap->data.data(), img_u8->nx * img_u8->ny * 3);
 
             // preprocess image
             clip_image_f32_batch batch_f32;
-            bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32);
+            bool ok = clip_image_preprocess(ctx->ctx_v, img_u8.get(), &batch_f32);
             if (!ok) {
                 LOG_ERR("Unable to preprocess image\n");
                 return 2;
@@ -440,7 +502,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
             ) {
                 // split batch into chunks of single images
-                auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_bm]->id);
+                auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmap->id);
                 GGML_ASSERT(chunks.size() > 0);
 
                 auto ov_chunk = std::move(chunks.front());
@@ -449,11 +511,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 // add overview image (first)
                 if (ctx->ov_img_first) {
                     if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
-                        add_text_chunk({ctx->tok_ov_img_start});
+                        add_text({ctx->tok_ov_img_start});
                     }
-                    output->entries.emplace_back(std::move(ov_chunk));
+                    cur.entries.emplace_back(std::move(ov_chunk));
                     if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
-                        add_text_chunk({ctx->tok_ov_img_end});
+                        add_text({ctx->tok_ov_img_end});
                     }
                 }
 
@@ -462,53 +524,53 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                     const int n_col = batch_f32.grid_x;
                     const int n_row = batch_f32.grid_y;
                     if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) {
-                        add_text_chunk({ctx->tok_slices_start});
+                        add_text({ctx->tok_slices_start});
                     }
                     for (int y = 0; y < n_row; y++) {
                         for (int x = 0; x < n_col; x++) {
                             const bool is_last_in_row = (x == n_col - 1);
                             if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) {
-                                add_text_chunk({ctx->tok_sli_img_start});
+                                add_text({ctx->tok_sli_img_start});
                             }
-                            output->entries.emplace_back(std::move(chunks[y * n_col + x]));
+                            cur.entries.emplace_back(std::move(chunks[y * n_col + x]));
                             if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) {
-                                add_text_chunk({ctx->tok_sli_img_end});
+                                add_text({ctx->tok_sli_img_end});
                             }
                             if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) {
-                                add_text_chunk({ctx->tok_sli_img_mid});
+                                add_text({ctx->tok_sli_img_mid});
                             }
                         }
                         if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) {
-                            add_text_chunk({ctx->tok_row_end});
+                            add_text({ctx->tok_row_end});
                         }
                     }
                     if (ctx->tok_slices_end != LLAMA_TOKEN_NULL) {
-                        add_text_chunk({ctx->tok_slices_end});
+                        add_text({ctx->tok_slices_end});
                     }
                 }
 
                 // add overview image (last)
                 if (!ctx->ov_img_first) {
                     if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
-                        add_text_chunk({ctx->tok_ov_img_start});
+                        add_text({ctx->tok_ov_img_start});
                     }
-                    output->entries.emplace_back(std::move(ov_chunk));
+                    cur.entries.emplace_back(std::move(ov_chunk));
                     if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
-                        add_text_chunk({ctx->tok_ov_img_end});
+                        add_text({ctx->tok_ov_img_end});
                     }
                 }
 
             } else {
                 size_t n_tokens = 0;
                 for (const auto & entry : batch_f32.entries) {
-                    n_tokens += clip_n_output_tokens(ctx->ctx_clip, entry.get());
+                    n_tokens += clip_n_output_tokens(ctx->ctx_v, entry.get());
                 }
 
                 mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
                 if (ctx->use_mrope) {
                     // for Qwen2VL, we need this information for M-RoPE decoding positions
-                    image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_clip, batch_f32.entries[0].get());
-                    image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_clip, batch_f32.entries[0].get());
+                    image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get());
+                    image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get());
                     image_tokens->use_mrope_pos = true;
                 } else {
                     // other models, we only need the total number of tokens
@@ -516,7 +578,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                     image_tokens->ny = 1;
                 }
                 image_tokens->batch_f32 = std::move(batch_f32);
-                image_tokens->id = bitmaps[i_bm]->id; // optional
+                image_tokens->id = bitmap->id; // optional
 
                 LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx);
                 LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny);
@@ -528,35 +590,35 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                     std::move(image_tokens),
                     nullptr, // audio tokens
                 };
-                output->entries.emplace_back(std::move(chunk));
+                cur.entries.emplace_back(std::move(chunk));
             }
 
-            i_bm++; // move to next image
-            continue;
+            if (!ctx->img_end.empty()) {
+                add_text(ctx->img_end, true); // add image end token
+            }
 
         } else {
             // handle audio
 
-            if (i_bm >= n_bitmaps) {
-                LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
-                return 1;
-            }
-
-            if (!ctx->has_audio) {
+            if (!ctx->ctx_a) {
                 LOG_ERR("%s: error: model does not support audio input\n", __func__);
                 return 2;
             }
 
-            if (bitmaps[i_bm]->data.size() == 0) {
+            if (bitmap->data.size() == 0) {
                 LOG_ERR("%s: error: empty audio data\n", __func__);
                 return 2;
             }
 
+            if (!ctx->aud_beg.empty()) {
+                add_text(ctx->aud_beg, true); // add audio begin token
+            }
+
             // preprocess audio
             GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded
             std::vector<whisper_preprocessor::whisper_mel> mel_spec_chunks;
-            const float * samples = (const float *)bitmaps[i_bm]->data.data();
-            size_t n_samples = bitmaps[i_bm]->data.size() / sizeof(float);
+            const float * samples = (const float *)bitmap->data.data();
+            size_t n_samples = bitmap->data.size() / sizeof(float);
             bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks);
             if (!ok) {
                 LOG_ERR("Unable to preprocess audio\n");
@@ -570,7 +632,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 mel_f32->nx  = mel_spec.n_len;
                 mel_f32->ny  = mel_spec.n_mel;
                 mel_f32->buf = std::move(mel_spec.data);
-                size_t n_tokens = clip_n_output_tokens(ctx->ctx_clip, mel_f32.get());
+                size_t n_tokens = clip_n_output_tokens(ctx->ctx_a, mel_f32.get());
 
                 clip_image_f32_batch batch_f32;
                 batch_f32.is_audio = true;
@@ -579,7 +641,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 mtmd_audio_tokens_ptr audio_tokens(new mtmd_audio_tokens);
                 audio_tokens->n_tokens = n_tokens;
                 audio_tokens->batch_f32 = std::move(batch_f32);
-                audio_tokens->id = bitmaps[i_bm]->id; // optional
+                audio_tokens->id = bitmap->id; // optional
 
                 LOG_DBG("audio_tokens->n_tokens = %d\n", audio_tokens->n_tokens);
 
@@ -589,15 +651,88 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                     nullptr, // image tokens
                     std::move(audio_tokens),
                 };
-                output->entries.emplace_back(std::move(chunk));
+                cur.entries.emplace_back(std::move(chunk));
             }
 
-            i_bm++;
-            continue;
+            if (!ctx->aud_end.empty()) {
+                add_text(ctx->aud_end, true); // add audio end token
+            }
         }
+
+        return 0;
     }
 
-    return 0;
+    std::vector<mtmd_input_chunk> split_batch_to_chunk(clip_image_f32_batch && batch_f32, const std::string & id) {
+        std::vector<mtmd_input_chunk> chunks;
+
+        for (auto & entry : batch_f32.entries) {
+            mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
+            image_tokens->nx = clip_n_output_tokens(ctx->ctx_v, entry.get());
+            image_tokens->ny = 1;
+            image_tokens->batch_f32.entries.push_back(std::move(entry));
+            image_tokens->id = id;
+
+            mtmd_input_chunk chunk{
+                MTMD_INPUT_CHUNK_TYPE_IMAGE,
+                {}, // text tokens
+                std::move(image_tokens),
+                nullptr, // audio tokens
+            };
+            chunks.emplace_back(std::move(chunk));
+        }
+
+        return chunks;
+    }
+
+    // for example: "a <__media__> b <__media__> c" --> "a", "<__media__>", "b", "<__media__>", "c"
+    static std::vector<std::string> split_text(const std::string & input, const std::string & delimiter) {
+        std::vector<std::string> result;
+        if (input.empty()) {
+            return result;
+        }
+        size_t start = 0;
+        size_t pos = 0;
+        while ((pos = input.find(delimiter, start)) != std::string::npos) {
+            if (pos > start) {
+                result.push_back(input.substr(start, pos - start));
+            }
+            result.push_back(delimiter);
+            start = pos + delimiter.length();
+        }
+        if (start < input.length()) {
+            result.push_back(input.substr(start));
+        }
+        return result;
+    }
+
+    // copied from common_tokenize
+    static std::vector<llama_token> mtmd_tokenize_text_internal(
+        const struct llama_vocab * vocab,
+               const std::string & text,
+                            bool   add_special,
+                            bool   parse_special) {
+        // upper limit for the number of tokens
+        int n_tokens = text.length() + 2 * add_special;
+        std::vector<llama_token> result(n_tokens);
+        n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        if (n_tokens < 0) {
+            result.resize(-n_tokens);
+            int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+            GGML_ASSERT(check == -n_tokens);
+        } else {
+            result.resize(n_tokens);
+        }
+        return result;
+    }
+};
+
+int32_t mtmd_tokenize(mtmd_context * ctx,
+            mtmd_input_chunks * output,
+            const mtmd_input_text * text,
+            const mtmd_bitmap ** bitmaps,
+            size_t n_bitmaps) {
+    mtmd_tokenizer tokenizer(ctx, text, bitmaps, n_bitmaps);
+    return tokenizer.tokenize(output);
 }
 
 int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
@@ -605,41 +740,54 @@ int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
         LOG_WRN("mtmd_encode_chunk has no effect for text chunks\n");
         return 0;
     } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        if (!ctx->ctx_v) {
+            LOG_ERR("%s: model does not support vision input\n", __func__);
+            return 1;
+        }
         return mtmd_encode(ctx, chunk->tokens_image.get());
     } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
-        int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+        if (!ctx->ctx_a) {
+            LOG_ERR("%s: model does not support audio input\n", __func__);
+            return 1;
+        }
+        int n_mmproj_embd = ctx->n_embd_text;
         ctx->image_embd_v.resize(chunk->tokens_audio->n_tokens * n_mmproj_embd);
         bool ok = clip_image_batch_encode(
-            ctx->ctx_clip,
+            ctx->ctx_a,
             ctx->n_threads,
             &chunk->tokens_audio->batch_f32,
             ctx->image_embd_v.data());
         return ok ? 0 : 1;
     }
 
-    LOG_ERR("mtmd_encode_chunk: unknown chunk type %d\n", (int)chunk->type);
+    LOG_ERR("%s: unknown chunk type %d\n", __func__, (int)chunk->type);
     return 1;
 }
 
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
-    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+    clip_ctx * ctx_clip = ctx->ctx_v;
+    if (!ctx_clip) {
+        LOG_ERR("%s: this API does not support non-vision input, please use mtmd_encode_chunk instead\n", __func__);
+        return 1;
+    }
+    int n_mmproj_embd = clip_n_mmproj_embd(ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
     bool ok = false;
 
-    if (clip_is_llava(ctx->ctx_clip) || clip_is_minicpmv(ctx->ctx_clip) || clip_is_glm(ctx->ctx_clip)) {
+    if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) || clip_is_glm(ctx_clip)) {
         // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
         const auto & entries = image_tokens->batch_f32.entries;
         for (size_t i = 0; i < entries.size(); i++) {
-            int n_tokens_per_image = clip_n_output_tokens(ctx->ctx_clip, entries[i].get());
+            int n_tokens_per_image = clip_n_output_tokens(ctx_clip, entries[i].get());
             ok = clip_image_encode(
-                ctx->ctx_clip,
+                ctx_clip,
                 ctx->n_threads,
                 entries[i].get(),
                 ctx->image_embd_v.data() + i*n_mmproj_embd*n_tokens_per_image);
         }
     } else {
         ok = clip_image_batch_encode(
-            ctx->ctx_clip,
+            ctx_clip,
             ctx->n_threads,
             &image_tokens->batch_f32,
             ctx->image_embd_v.data());
@@ -653,8 +801,7 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
 }
 
 bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
-    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
-    if (proj_type == PROJECTOR_TYPE_GEMMA3) {
+    if (ctx->ctx_v && clip_get_projector_type(ctx->ctx_v) == PROJECTOR_TYPE_GEMMA3) {
         return true;
     }
     return false;
@@ -665,11 +812,11 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) {
 }
 
 bool mtmd_support_vision(mtmd_context * ctx) {
-    return ctx->has_vision;
+    return ctx->ctx_v != nullptr;
 }
 
 bool mtmd_support_audio(mtmd_context * ctx) {
-    return ctx->has_audio;
+    return ctx->ctx_a != nullptr;
 }
 
 // these 2 helpers below use internal clip_image_u8_ptr,
diff --git a/tools/mtmd/test-2.mp3 b/tools/mtmd/test-2.mp3
new file mode 100644 (file)
index 0000000..aa9d7ec
Binary files /dev/null and b/tools/mtmd/test-2.mp3 differ
index 15a37b0d22bb46ae0bed8a06630e1b4e8df87d30..aa0019893283ecbda1cd387a5e19a5144994f114 100755 (executable)
@@ -25,80 +25,99 @@ RUN_HUGE_TESTS=false
 if [ "${1:-}" = "huge" ]; then
     RUN_HUGE_TESTS=true
     RUN_BIG_TESTS=true
-    echo "Include BIG models..."
+    echo "Include BIG and HUGE models..."
 fi
 
 ###############
 
-arr_bin=()
+arr_prefix=()
 arr_hf=()
 arr_tmpl=() # chat template
+arr_file=()
 
-add_test() {
-    local bin=$1
-    local hf=$2
-    local tmpl=${3:-""} # default to empty string if not provided
-    arr_bin+=("$bin")
+add_test_vision() {
+    local hf=$1
+    local tmpl=${2:-""} # default to empty string if not provided
+    arr_prefix+=("[vision]")
     arr_hf+=("$hf")
     arr_tmpl+=("$tmpl")
+    arr_file+=("test-1.jpeg")
+}
+
+add_test_audio() {
+    local hf=$1
+    arr_prefix+=("[audio] ")
+    arr_hf+=("$hf")
+    arr_tmpl+=("") # no need for chat tmpl
+    arr_file+=("test-2.mp3")
 }
 
-add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
-add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
-add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
-add_test "llama-mtmd-cli"  "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
-add_test "llama-mtmd-cli"  "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
-add_test "llama-mtmd-cli"  "second-state/Llava-v1.5-7B-GGUF:Q2_K"            "vicuna"
-add_test "llama-mtmd-cli"  "cjpais/llava-1.6-mistral-7b-gguf:Q3_K_M"         "vicuna"
-add_test "llama-mtmd-cli"  "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
-add_test "llama-mtmd-cli"  "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K"  # model from openbmb is corrupted
-add_test "llama-mtmd-cli"  "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
-add_test "llama-mtmd-cli"  "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
-add_test "llama-mtmd-cli"  "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
-add_test "llama-mtmd-cli"  "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
-add_test "llama-mtmd-cli"  "ggml-org/InternVL2_5-1B-GGUF:Q8_0"
-add_test "llama-mtmd-cli"  "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0"
+add_test_vision "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
+add_test_vision "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
+add_test_vision "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
+add_test_vision "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
+add_test_vision "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
+add_test_vision "second-state/Llava-v1.5-7B-GGUF:Q2_K"            "vicuna"
+add_test_vision "cjpais/llava-1.6-mistral-7b-gguf:Q3_K_M"         "vicuna"
+add_test_vision "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
+add_test_vision "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K"  # model from openbmb is corrupted
+add_test_vision "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
+add_test_vision "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
+add_test_vision "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
+add_test_vision "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
+add_test_vision "ggml-org/InternVL2_5-1B-GGUF:Q8_0"
+add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0"
+add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
+
+add_test_audio  "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
+add_test_audio  "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
 
 # to test the big models, run: ./tests.sh big
 if [ "$RUN_BIG_TESTS" = true ]; then
-    add_test "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
-    add_test "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
-    add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
-    add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
-    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
-    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
-    add_test "llama-mtmd-cli" "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
-    add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
-    # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
+    add_test_vision "ggml-org/pixtral-12b-GGUF:Q4_K_M"
+    add_test_vision "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
+    add_test_vision "ggml-org/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
+    add_test_vision "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
+    add_test_vision "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
+    add_test_vision "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
+    add_test_vision "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
+    add_test_vision "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
+    add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
+    # add_test_vision "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
+
+    add_test_audio  "ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF:Q4_K_M"
+    add_test_audio  "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
 fi
 
 # to test the huge models, run: ./tests.sh huge
 # this will run both the big and huge models
 # huge models are > 32B parameters
 if [ "$RUN_HUGE_TESTS" = true ]; then
-    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M"
-    add_test "llama-mtmd-cli" "ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S"
+    add_test_vision "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M"
+    add_test_vision "ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S"
 fi
 
 # these models always give the wrong answer, not sure why
-# add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
-# add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0"
-# add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-256M-Video-Instruct-GGUF:Q8_0"
+# add_test_vision "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
+# add_test_vision "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0"
+# add_test_vision "ggml-org/SmolVLM2-256M-Video-Instruct-GGUF:Q8_0"
 
 # this model has broken chat template, not usable
-# add_test "llama-mtmd-cli"  "cmp-nct/Yi-VL-6B-GGUF:Q5_K"
-# add_test "llama-mtmd-cli"  "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek"
+# add_test_vision "cmp-nct/Yi-VL-6B-GGUF:Q5_K"
+# add_test_vision "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek"
 
 ###############
 
-cmake --build build -j --target "${arr_bin[@]}"
+cmake --build build -j --target llama-mtmd-cli
 
 arr_res=()
 
-for i in "${!arr_bin[@]}"; do
-    bin="${arr_bin[$i]}"
+for i in "${!arr_hf[@]}"; do
+    bin="llama-mtmd-cli"
+    prefix="${arr_prefix[$i]}"
     hf="${arr_hf[$i]}"
     tmpl="${arr_tmpl[$i]}"
+    inp_file="${arr_file[$i]}"
 
     echo "Running test with binary: $bin and HF model: $hf"
     echo ""
@@ -107,7 +126,7 @@ for i in "${!arr_bin[@]}"; do
     output=$(\
         "$PROJ_ROOT/build/bin/$bin" \
         -hf "$hf" \
-        --image $SCRIPT_DIR/test-1.jpeg \
+        --image $SCRIPT_DIR/$inp_file \
         -p "what is the publisher name of the newspaper?" \
         --temp 0 -n 128 \
         ${tmpl:+--chat-template "$tmpl"} \
@@ -116,9 +135,9 @@ for i in "${!arr_bin[@]}"; do
     echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log
 
     if echo "$output" | grep -iq "new york"; then
-        result="\033[32mOK\033[0m:   $bin $hf"
+        result="$prefix \033[32mOK\033[0m:   $bin $hf"
     else
-        result="\033[31mFAIL\033[0m: $bin $hf"
+        result="$prefix \033[31mFAIL\033[0m: $bin $hf"
     fi
     echo -e "$result"
     arr_res+=("$result")