model: support youtu-vl model (#18479)
author    tt <redacted>
          Thu, 1 Jan 2026 18:25:54 +0000 (02:25 +0800)
committer GitHub <redacted>
          Thu, 1 Jan 2026 18:25:54 +0000 (19:25 +0100)
* Support Youtu-VL Model

* merge code

* fix bug

* revert qwen2 code & support rsplit in minja.hpp

* update warm info

* fix annotation

* u

* revert minja.hpp

* fix

* Do not write routed_scaling_factor to gguf when routed_scaling_factor is None

* fix expert_weights_scale

* LGTM after whitespace fixes

* fix

* fix

* fix

* layers to layer_index

* enum fix

---------

Co-authored-by: Xuan-Son Nguyen <redacted>
Co-authored-by: Sigbjørn Skjæret <redacted>
17 files changed:
convert_hf_to_gguf.py
convert_hf_to_gguf_update.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
src/llama-model.cpp
src/llama-vocab.cpp
src/llama-vocab.h
src/models/deepseek2.cpp
src/unicode.cpp
tools/mtmd/CMakeLists.txt
tools/mtmd/clip-impl.h
tools/mtmd/clip-model.h
tools/mtmd/clip.cpp
tools/mtmd/models/models.h
tools/mtmd/models/youtuvl.cpp [new file with mode: 0644]
tools/mtmd/mtmd.cpp

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index cfa60c4ede144bd8da3d3aca41f146598499e54d..7ad20c0869bdab1500337247de1599cf0b687a50 100755 (executable)
@@ -1233,6 +1233,9 @@ class TextModel(ModelBase):
         if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665":
             # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer
             res = "kormo"
+        if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1":
+            # ref: https://huggingface.co/tencent/Youtu-LLM-2B
+            res = "youtu"
         if chkhsh == "16389f0a1f51ee53e562ffd51c371dc508639ab0e4261502071836e50e223e91":
             # ref: https://huggingface.co/upstage/Solar-Open-100B
             res = "solar-open"
@@ -7189,6 +7192,7 @@ class DeepseekModel(TextModel):
     "DeepseekV2ForCausalLM",
     "DeepseekV3ForCausalLM",
     "KimiVLForConditionalGeneration",
+    "YoutuForCausalLM",
 )
 class DeepseekV2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -7255,7 +7259,15 @@ class DeepseekV2Model(TextModel):
         super().set_gguf_parameters()
         hparams = self.hparams
 
-        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+        # first_k_dense_replace: number of leading layers using dense FFN instead of MoE
+        # For non-MoE models (like Youtu), set to n_layer to use dense FFN for all layers
+        # For MoE models (like DeepSeek-V2), this is the number of leading non-MoE layers
+        has_moe = hparams.get("n_routed_experts") is not None
+        first_k_dense_replace = hparams.get("first_k_dense_replace")
+        if first_k_dense_replace is None:
+            # Default: if no MoE, all layers are dense; if MoE, none are dense
+            first_k_dense_replace = hparams["num_hidden_layers"] if not has_moe else 0
+        self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
             self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
@@ -7267,11 +7279,24 @@ class DeepseekV2Model(TextModel):
         self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
         self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
 
-        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
-        self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
-        self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
-        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
-        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+        # MoE parameters (required by C++ code for DEEPSEEK2 arch)
+        # For non-MoE models like Youtu, use intermediate_size as expert_feed_forward_length
+        moe_intermediate_size = self.find_hparam(["moe_intermediate_size", "intermediate_size"], optional=False)
+        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+
+        if (n_routed_experts := hparams.get("n_routed_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_routed_experts)
+
+        # expert_shared_count is required by C++ code, default to 0 for non-MoE models
+        n_shared_experts = hparams.get("n_shared_experts", 0)
+        self.gguf_writer.add_expert_shared_count(n_shared_experts)
+
+        # When not set, C++ code will use scale_w = false to skip the no-op scaling
+        if (routed_scaling_factor := hparams.get("routed_scaling_factor")) is not None:
+            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
+
+        if (norm_topk_prob := hparams.get("norm_topk_prob")) is not None and norm_topk_prob:
+            self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
 
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
@@ -7287,10 +7312,17 @@ class DeepseekV2Model(TextModel):
         # skip vision tensors and remove "language_model." for Kimi-VL
         if "vision_tower" in name or "multi_modal_projector" in name:
             return []
-
+        if name.startswith("siglip2.") or name.startswith("merger."):
+            return []
         if name.startswith("language_model."):
             name = name.replace("language_model.", "")
 
+        # skip lm_head.weight if tie_word_embeddings is True
+        if self.hparams.get("tie_word_embeddings", False):
+            if name == "lm_head.weight" or name == "model.lm_head.weight":
+                logger.info("Skipping tied output layer 'lm_head.weight' (will use token_embd.weight)")
+                return []
+
         # rename e_score_correction_bias tensors
         if name.endswith("e_score_correction_bias"):
             name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@@ -10625,6 +10657,59 @@ class JanusProVisionModel(MmprojModel):
         return []
 
 
+@ModelBase.register("YOUTUVLForConditionalGeneration", "YOUTUVLForCausalLM")
+class YOUTUVLVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.YOUTUVL)
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
+
+        # Handle activation function
+        hidden_act = str(self.hparams.get("hidden_act", "gelu_pytorch_tanh")).lower()
+        if hidden_act in ("gelu", "gelu_pytorch_tanh", "gelu_fast", "gelu_new", "gelu_accurate"):
+            self.gguf_writer.add_vision_use_gelu(True)
+        elif hidden_act == "silu":
+            self.gguf_writer.add_vision_use_silu(True)
+        else:
+            raise ValueError(f"Unsupported activation function for YOUTUVL: {hidden_act}")
+
+        self.gguf_writer.add_vision_spatial_merge_size(self.hparams.get("spatial_merge_size", 2))
+
+        window_size = self.hparams.get("window_size")
+        if window_size is not None:
+            self.gguf_writer.add_vision_window_size(window_size)
+        # fullatt_block_indexes contains explicit layer indices that use full attention
+        # e.g., [2, 5, 8, 11] means layers 2, 5, 8, 11 use full attention
+        # All other layers use window attention
+        fullatt_block_indexes = self.hparams.get("fullatt_block_indexes")
+        assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for youtuvl"
+        # Store the explicit layer indices for YoutuVL (irregular pattern approach)
+        self.gguf_writer.add_vision_wa_layer_indexes(layers=fullatt_block_indexes)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # Skip language model tensors
+        skip_prefixes = ('lm_head.', 'model.layers.', 'model.embed_tokens.', 'model.norm.')
+        if name.startswith(skip_prefixes):
+            return []
+
+        # Try to map the tensor using TensorNameMap (handles vision encoder and projector)
+        try:
+            new_name = self.map_tensor_name(name)
+            return [(new_name, data_torch)]
+        except ValueError:
+            # If mapping fails, log warning and skip
+            logger.warning(f"Cannot map tensor: {name}")
+            return []
+
+
 @ModelBase.register("SolarOpenForCausalLM")
 class SolarOpenModel(Glm4MoeModel):
     model_arch = gguf.MODEL_ARCH.GLM4_MOE
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index 243cf8a29b8fa092c47a96c0df3d8928b5c91caf..74c67e6a9c0aee64d085d69b22c224e6e1fcd69f 100755 (executable)
@@ -145,6 +145,7 @@ models = [
     {"name": "granite-docling",  "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
     {"name": "minimax-m2",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
     {"name": "kormo",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", },
+    {"name": "youtu",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
     {"name": "solar-open",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
 ]
 
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index c2a0f41c1bce613682f079e67f0ff972e95ac5b5..0ac512ff367bb03cb8b47556df5a97b3a0b976ec 100644 (file)
@@ -294,7 +294,9 @@ class Keys:
         USE_GELU            = "clip.use_gelu"
         USE_SILU            = "clip.use_silu"
         N_WA_PATTERN        = "clip.vision.n_wa_pattern" # used by qwen2.5vl
+        WA_LAYER_INDEXES    = "clip.vision.wa_layer_indexes" # used by youtuvl
         IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"
+        WINDOW_SIZE         = "clip.vision.window_size"
 
         class Attention:
             HEAD_COUNT      = "clip.vision.attention.head_count"
@@ -3494,6 +3496,7 @@ class VisionProjectorType:
     LFM2A = "lfm2a" # audio
     MUSIC_FLAMINGO = "musicflamingo" # audio
     GLM4V = "glm4v"
+    YOUTUVL = "youtuvl"
 
 
 # Items here are (block size, type size)
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 6a4a504f8dcb8deaa26e31b26b52076840d269d9..612a978e4c30c82dc667f75191d435299b51885d 100644 (file)
@@ -1129,11 +1129,40 @@ class GGUFWriter:
         self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
 
     def add_vision_n_wa_pattern(self, value: int) -> None:
+        """Add window attention pattern interval for vision models.
+
+        This defines the pattern interval for window attention vs full attention layers.
+        For example, if n_wa_pattern=4, then layers 3, 7, 11, ... use full attention,
+        while other layers use window attention.
+
+        Used by models like Qwen2.5-VL where full attention layers follow a regular pattern.
+        """
         self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
 
+    def add_vision_wa_layer_indexes(self, layers: Sequence[int]) -> None:
+        """Add explicit layer indexes that use full attention in vision models.
+
+        This specifies the exact layer indices (0-based) that should use full attention
+        instead of window attention. All other layers will use window attention.
+
+        Args:
+            layers: List of layer indices that use full attention (e.g., [3, 7, 11, 15])
+
+        Used by models like YoutuVL where full attention layers are explicitly specified
+        rather than following a regular pattern.
+
+        Difference from add_vision_n_wa_pattern:
+        - n_wa_pattern: Defines a regular interval pattern (every Nth layer uses full attention)
+        - wa_layer_indexes: Explicitly lists which layers use full attention (irregular pattern)
+        """
+        self.add_array(Keys.ClipVision.WA_LAYER_INDEXES, layers)
+
     def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
         self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
 
+    def add_vision_window_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value)
+
     # audio models
 
     def add_audio_projection_dim(self, value: int) -> None:
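
To make the difference between the two metadata keys documented above concrete, a small illustrative sketch (toy values, independent of GGUFWriter, not part of this patch) that computes which layers end up with full attention under each scheme:

    # n_wa_pattern (regular interval) vs. wa_layer_indexes (explicit list), toy values.
    def full_attn_layers_from_pattern(n_layer: int, n_wa_pattern: int) -> list[int]:
        # every n_wa_pattern-th layer uses full attention (e.g. 3, 7, 11, ... for pattern 4)
        return [il for il in range(n_layer) if (il + 1) % n_wa_pattern == 0]

    def full_attn_layers_from_indexes(n_layer: int, wa_layer_indexes: set[int]) -> list[int]:
        # explicit, possibly irregular list of full-attention layers
        return [il for il in range(n_layer) if il in wa_layer_indexes]

    print(full_attn_layers_from_pattern(16, 4))              # [3, 7, 11, 15]  (Qwen2.5-VL style)
    print(full_attn_layers_from_indexes(16, {2, 5, 8, 11}))  # [2, 5, 8, 11]   (YoutuVL style)
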
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 115df6c7c37db3149b581b3b7147f316abe1d5c4..64dd4ddca50ca0385f5fcd5e6952608a44a08ed6 100644 (file)
@@ -1221,6 +1221,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_MMPROJ: (
             "multi_modal_projector.linear_{bid}",
             "visual.merger.mlp.{bid}", # qwen2vl
+            "merger.mlp.{bid}",
         ),
 
         MODEL_TENSOR.V_MMPROJ_FC: (
@@ -1258,6 +1259,7 @@ class TensorNameMap:
             "visual.patch_embed.proj", # qwen2vl
             "vision_tower.patch_embed.proj", # kimi-vl
             "model.vision.patch_embedding.proj", # cogvlm
+            "siglip2.vision_model.embeddings.patch_embedding",
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_NORM: (
@@ -1291,6 +1293,7 @@ class TensorNameMap:
             "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
             "visual.blocks.{bid}.attn.q", # qwen2vl, generated
             "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
+            "siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
@@ -1308,6 +1311,7 @@ class TensorNameMap:
             "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
             "visual.blocks.{bid}.attn.k", # qwen2vl, generated
             "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
+            "siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj",
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
@@ -1325,6 +1329,7 @@ class TensorNameMap:
             "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
             "visual.blocks.{bid}.attn.v", # qwen2vl, generated
             "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
+            "siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj",
         ),
 
         MODEL_TENSOR.V_ENC_INPUT_NORM: (
@@ -1339,6 +1344,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.norm1", # qwen2vl
             "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
             "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
+            "siglip2.vision_model.encoder.layers.{bid}.layer_norm1",
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_O: (
@@ -1354,6 +1360,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.attn.proj", # qwen2vl
             "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
             "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
+            "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl
         ),
 
         MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@@ -1368,6 +1375,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.norm2", # qwen2vl
             "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
             "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
+            "siglip2.vision_model.encoder.layers.{bid}.layer_norm2",
         ),
 
         MODEL_TENSOR.V_ENC_FFN_UP: (
@@ -1383,6 +1391,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl
             "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
             "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
+            "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1",
         ),
 
         MODEL_TENSOR.V_ENC_FFN_GATE: (
@@ -1404,6 +1413,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl
             "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
             "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
+            "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2",
         ),
 
         MODEL_TENSOR.V_LAYER_SCALE_1: (
@@ -1430,6 +1440,7 @@ class TensorNameMap:
             "visual.merger.ln_q", # qwen2vl
             "vision_tower.encoder.final_layernorm", # kimi-vl
             "visual.post_layernorm", # glm4v
+            "siglip2.vision_model.post_layernorm",
         ),
 
         MODEL_TENSOR.V_MM_POST_NORM: (
@@ -1446,6 +1457,7 @@ class TensorNameMap:
             "multi_modal_projector.pre_norm",
             "pre_mm_projector_norm",
             "model.vision.linear_proj.norm1", # cogvlm
+            "merger.ln_q",
         ),
 
         MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index d8b1221df5ac94b0219c9162b8e40bdbf07afef4..c2cd44de448a7b7a78133902bae9d6d121b3e5aa 100644 (file)
@@ -1683,7 +1683,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false);
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                 ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,        hparams.n_expert_shared);
-                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,       hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,       hparams.expert_weights_scale, false);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,        hparams.expert_weights_norm, false);
                 ml.get_key(LLM_KV_EXPERT_GATING_FUNC,         hparams.expert_gating_func, false);
                 if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
@@ -4785,7 +4785,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    // try to load output.weight, if not found, use token_embd (tied embeddings)
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    if (!output) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -4848,7 +4852,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    // try to load output.weight, if not found, use token_embd (tied embeddings)
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    if (!output) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index c57055082b0c556607715750c0b3ab0197b3518a..bd311bea45b382cc8d78fbc92cbfd80decb159f9 100644 (file)
@@ -314,6 +314,12 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_YOUTU:
+                regex_exprs = {
+                    "[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥぀-ゟ゠-ヿ]+",
+                    "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
                 regex_exprs = {
                     "[\r\n]",
@@ -1861,6 +1867,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     tokenizer_pre == "deepseek-v3") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
                 clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "youtu") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_YOUTU;
+                clean_spaces = false;
+                ignore_merges = true;
             } else if (
                     tokenizer_pre == "falcon") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON;
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index f5bdd2231183980a83ae1c67d4b10ae761aad78f..2b240a5491bed19bebf2aa83fc5477eafe97eab5 100644 (file)
@@ -52,6 +52,7 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2      = 41,
     LLAMA_VOCAB_PRE_TYPE_AFMOE           = 42,
     LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN      = 43,
+    LLAMA_VOCAB_PRE_TYPE_YOUTU           = 44,
 };
 
 struct LLM_KV;
diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp
index 49382874baae10ec831ed07c24fd6f7933e1a57f..ca63a62ad1b1e3e6f69e0ba21a05cf8b71288112 100644 (file)
@@ -215,7 +215,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 model.layers[il].ffn_exp_probs_b,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, hparams.expert_weights_norm,
-                true, hparams.expert_weights_scale,
+                hparams.expert_weights_scale, hparams.expert_weights_scale,
                 (llama_expert_gating_func_type) hparams.expert_gating_func,
                 il);
             cb(moe_out, "ffn_moe_out", il);
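
The effect of the changed scale_w argument can be summarized with a short illustrative sketch (plain Python, toy values, not the actual ggml call): when routed_scaling_factor is absent the scale stays at its zero default, so the scaling step is skipped instead of multiplying by a no-op factor.

    # Illustrative sketch of the scale_w behaviour referenced in the conversion comments above.
    def apply_routed_scale(weights: list[float], expert_weights_scale: float) -> list[float]:
        scale_w = bool(expert_weights_scale)   # 0.0 (key not written to GGUF) -> False -> skip
        if not scale_w:
            return weights
        return [w * expert_weights_scale for w in weights]

    print(apply_routed_scale([0.4, 0.6], 0.0))   # [0.4, 0.6]  (Youtu: no routed_scaling_factor)
    print(apply_routed_scale([0.4, 0.6], 2.5))   # [1.0, 1.5]  (DeepSeek-V2 style scaling)
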
diff --git a/src/unicode.cpp b/src/unicode.cpp
index bb44edfaddffdbf159c26629a815fa8c7a4d8cd8..b47dcbe6198a82cd42ab171925aa1200eb1603e8 100644 (file)
@@ -964,6 +964,11 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
         { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
         { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
         { "\\p{S}", unicode_cpt_flags::SYMBOL },
+        { "\\p{Lu}", unicode_cpt_flags::LETTER }, // Uppercase letter
+        { "\\p{Ll}", unicode_cpt_flags::LETTER }, // Lowercase letter
+        { "\\p{Lt}", unicode_cpt_flags::LETTER }, // Titlecase letter
+        { "\\p{Lm}", unicode_cpt_flags::LETTER }, // Modifier letter
+        { "\\p{Lo}", unicode_cpt_flags::LETTER }, // Other letter
     };
 
     static const std::map<int, int> k_ucat_cpt = {
@@ -1074,22 +1079,26 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
                         continue;
                     }
 
-                    if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
+                    // Match \p{...} Unicode properties of varying lengths
+                    if (regex_expr[i + 0] == '\\' && i + 3 < regex_expr.size() &&
                         regex_expr[i + 1] == 'p' &&
-                        regex_expr[i + 2] == '{' &&
-                        regex_expr[i + 4] == '}') {
-                        const std::string pat = regex_expr.substr(i, 5);
-                        if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
-                            if (!inside) {
-                                regex_expr_collapsed += '[';
+                        regex_expr[i + 2] == '{') {
+                        // Find the closing brace
+                        size_t closing_brace = regex_expr.find('}', i + 3);
+                        if (closing_brace != std::string::npos && closing_brace <= i + 10) { // reasonable limit
+                            const std::string pat = regex_expr.substr(i, closing_brace - i + 1);
+                            if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
+                                if (!inside) {
+                                    regex_expr_collapsed += '[';
+                                }
+                                regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
+                                regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
+                                if (!inside) {
+                                    regex_expr_collapsed += ']';
+                                }
+                                i = closing_brace;
+                                continue;
                             }
-                            regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
-                            regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
-                            if (!inside) {
-                                regex_expr_collapsed += ']';
-                            }
-                            i += 4;
-                            continue;
                         }
                     }
 
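A compact Python sketch (toy lookup table and simplified output, not part of this patch) of the generalized \p{...} collapsing implemented above: find the closing brace within a reasonable limit, look the whole pattern up, and fall through unchanged if it is unknown. The real code also tracks whether it is already inside a [...] character class and emits codepoint/flag markers instead of letters.

    # Toy table standing in for the C++ k_ucat_* maps.
    K_UCAT = {r"\p{L}": "A", r"\p{Lu}": "A", r"\p{Ll}": "A", r"\p{N}": "B", r"\p{P}": "C"}

    def collapse(regex_expr: str) -> str:
        out, i = "", 0
        while i < len(regex_expr):
            if regex_expr.startswith(r"\p{", i):
                closing = regex_expr.find("}", i + 3)
                if closing != -1 and closing <= i + 10:          # reasonable length limit
                    pat = regex_expr[i:closing + 1]              # whole \p{...} pattern
                    if pat in K_UCAT:
                        out += "[" + K_UCAT[pat] + "]"
                        i = closing + 1
                        continue
            out += regex_expr[i]
            i += 1
        return out

    print(collapse(r"\p{Lu}+\p{Ll}*|\p{N}"))   # [A]+[A]*|[B]
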
diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt
index 317d5f19fd949f6e4c014e486432e081d4a52c26..4b9022cb5827c759b468ad22b9304c6a45c373e1 100644 (file)
@@ -27,6 +27,7 @@ add_library(mtmd
             models/qwen3vl.cpp
             models/siglip.cpp
             models/whisper-enc.cpp
+            models/youtuvl.cpp
             )
 
 set_target_properties(mtmd PROPERTIES
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 1ed07418831061e74e911e2db7ebbb350f1099f3..df7e479765b6ca9b2813cedddde282c8c05c16ba 100644 (file)
 #define KEY_SPATIAL_MERGE_SIZE  "clip.vision.spatial_merge_size"
 #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
 
-#define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
-#define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
-#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
-#define KEY_WIN_ATTN_PATTERN      "clip.vision.n_wa_pattern"
-#define KEY_ATTN_WINDOW_SIZE      "clip.vision.window_size"
-#define KEY_MINICPMV_VERSION      "clip.minicpmv_version"
-#define KEY_MINICPMV_QUERY_NUM    "clip.minicpmv_query_num"
+#define KEY_MM_PATCH_MERGE_TYPE    "clip.vision.mm_patch_merge_type"
+#define KEY_IMAGE_GRID_PINPOINTS   "clip.vision.image_grid_pinpoints"
+#define KEY_IMAGE_CROP_RESOLUTION  "clip.vision.image_crop_resolution"
+#define KEY_WIN_ATTN_PATTERN       "clip.vision.n_wa_pattern"
+#define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes"
+#define KEY_ATTN_WINDOW_SIZE       "clip.vision.window_size"
+#define KEY_MINICPMV_VERSION       "clip.minicpmv_version"
+#define KEY_MINICPMV_QUERY_NUM     "clip.minicpmv_query_num"
 
 // audio-specific
 #define KEY_AUDIO_PROJ_TYPE     "clip.audio.projector_type" // for models with mixed modalities
@@ -188,6 +189,7 @@ enum projector_type {
     PROJECTOR_TYPE_JANUS_PRO,
     PROJECTOR_TYPE_LFM2A,
     PROJECTOR_TYPE_GLM4V,
+    PROJECTOR_TYPE_YOUTUVL,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -218,6 +220,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
     { PROJECTOR_TYPE_LFM2A,     "lfm2a"},
     { PROJECTOR_TYPE_GLM4V,     "glm4v"},
+    { PROJECTOR_TYPE_YOUTUVL,   "youtuvl"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h
index 1e5aa87b986b2d9cf4ae764dc16ac2a2555f280d..702e10151a3729d67098f30b0fc3bb8df8ef95ca 100644 (file)
@@ -61,6 +61,7 @@ struct clip_hparams {
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size = 0;
     int32_t n_wa_pattern = 0;
+    std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
 
     // audio
     int32_t n_mel_bins = 0; // whisper preprocessor
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index fb08dd258c1d75965235dec224c451054df341ba..9f551e8f3cd654a0a79d2646ac4879b3b649d4f4 100644 (file)
@@ -846,6 +846,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique<clip_graph_glm4v>(ctx, img);
             } break;
+        case PROJECTOR_TYPE_YOUTUVL:
+            {
+                builder = std::make_unique<clip_graph_youtuvl>(ctx, img);
+            } break;
         default:
             GGML_ABORT("missing cgraph builder");
     }
@@ -1159,6 +1163,20 @@ struct clip_model_loader {
                             LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__);
                         }
                     } break;
+                case PROJECTOR_TYPE_YOUTUVL:
+                    {
+                        hparams.n_merge = 2;
+                        get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                        get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true);
+                        std::vector<int> wa_layer_indexes_vec;
+                        get_arr_int(KEY_WIN_ATTN_LAYER_INDEXES, wa_layer_indexes_vec, true);
+                        for (auto & layer : wa_layer_indexes_vec) {
+                            hparams.wa_layer_indexes.insert(layer);
+                        }
+                        // support max_height * max_width = 8000 * 8000. 8000/16/2 = 250 image tokens
+                        hparams.set_limit_image_tokens(1, 62500);
+                        hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup
+                    } break;
                 case PROJECTOR_TYPE_GLM4V:
                     {
                         hparams.rope_theta = 10000.0f;
@@ -1227,7 +1245,14 @@ struct clip_model_loader {
                 LOG_INF("%s: has_llava_proj:     %d\n", __func__, hparams.has_llava_projector);
                 LOG_INF("%s: minicpmv_version:   %d\n", __func__, hparams.minicpmv_version);
                 LOG_INF("%s: n_merge:            %d\n", __func__, hparams.n_merge);
-                LOG_INF("%s: n_wa_pattern:       %d\n", __func__, hparams.n_wa_pattern);
+                LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
+                if (!hparams.wa_layer_indexes.empty()) {
+                    LOG_INF("%s: wa_layer_indexes:  ", __func__);
+                    for (auto & layer : hparams.wa_layer_indexes) {
+                        LOG_INF("%d ", layer);
+                    }
+                    LOG_INF("\n");
+                }
                 if (hparams.image_min_pixels > 0) {
                     LOG_INF("%s: image_min_pixels:   %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : "");
                 }
@@ -1495,6 +1520,14 @@ struct clip_model_loader {
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
                 } break;
+            case PROJECTOR_TYPE_YOUTUVL:
+                {
+                    model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);        // merger.ln_q (RMS norm)
+                    model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));  // merger.mlp.0
+                    model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));  // merger.mlp.2
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                } break;
             case PROJECTOR_TYPE_GLM4V:
                 {
                     model.projection     = get_tensor(TN_MM_PROJECTOR);
@@ -2697,6 +2730,57 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                 // res_imgs->data[0] = *res;
                 res_imgs->entries.push_back(std::move(img_f32));
             } break;
+        case PROJECTOR_TYPE_YOUTUVL:
+            {
+                const int patch_size = params.patch_size;  // typically 16
+                const int merge_size = params.n_merge;      // typically 2
+                const int align_size = patch_size * merge_size;  // 32
+
+                const int max_num_patches = params.image_max_pixels > 0 ?
+                    params.image_max_pixels / (patch_size * patch_size) : 256;
+
+                // Linear search for optimal scale to fit within max_num_patches
+                float scale = 1.0f;
+                int target_height = original_size.height;
+                int target_width = original_size.width;
+
+                auto get_scaled_image_size = [align_size](float scale, int size) -> int {
+                    float scaled_size = size * scale;
+                    // Round up to nearest multiple of align_size
+                    int aligned = static_cast<int>(std::ceil(scaled_size / align_size)) * align_size;
+                    // Ensure at least one patch
+                    return std::max(align_size, aligned);
+                };
+
+                // Linear search with 0.02 step size
+                while (scale > 0.0f) {
+                    target_height = get_scaled_image_size(scale, original_size.height);
+                    target_width = get_scaled_image_size(scale, original_size.width);
+
+                    int num_patches_h = target_height / patch_size;
+                    int num_patches_w = target_width / patch_size;
+                    int num_patches = num_patches_h * num_patches_w;
+
+                    if (num_patches > max_num_patches) {
+                        scale -= 0.02f;
+                    } else {
+                        break;
+                    }
+                }
+
+                clip_image_size new_size = {target_width, target_height};
+
+                // Resize the image
+                clip_image_u8 resized;
+                img_tool::resize(*img, resized, new_size, img_tool::RESIZE_ALGO_BILINEAR, false);
+
+                // Normalize to float32
+                clip_image_f32_ptr img_f32(clip_image_f32_init());
+                normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std);
+
+                // Add to results
+                res_imgs->entries.push_back(std::move(img_f32));
+            } break;
 
         case PROJECTOR_TYPE_IDEFICS3:
             {
@@ -2929,6 +3013,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
         case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_YOUTUVL:
             return (img->nx / params.patch_size) / 2;
         default:
             break;
@@ -2944,6 +3029,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
         case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_YOUTUVL:
             return (img->ny / params.patch_size) / 2;
         default:
             break;
@@ -3004,6 +3090,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
         case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_YOUTUVL:
             {
                 // dynamic size (2 conv, so double patch size)
                 int x_patch = img->nx / (params.patch_size * 2);
@@ -3131,7 +3218,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const int pos_w = image_size_width  / patch_size;
     const int pos_h = image_size_height / patch_size;
 
-    const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
 
     auto get_inp_tensor = [&gf](const char * name) {
         ggml_tensor * inp = ggml_graph_get_tensor(gf, name);
@@ -3280,9 +3366,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 set_input_i32("positions", positions);
             } break;
         case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_YOUTUVL:
             {
                 // pw * ph = number of tokens output by ViT after apply patch merger
                 // ipw * ipw = number of vision token been processed inside ViT
+                const bool use_window_attn = ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL ? hparams.n_wa_pattern > 0 : !hparams.wa_layer_indexes.empty();
                 const int merge_ratio = 2;
                 const int pw  = image_size_width  / patch_size / merge_ratio;
                 const int ph  = image_size_height / patch_size / merge_ratio;
@@ -3293,7 +3381,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 std::vector<int> inv_idx(ph * pw);
 
                 if (use_window_attn) {
-                    const int attn_window_size = 112;
+                    const int attn_window_size = hparams.attn_window_size > 0 ? hparams.attn_window_size : 112;
                     const int grid_window = attn_window_size / patch_size / merge_ratio;
                     int dst = 0;
                     // [num_vision_tokens, num_vision_tokens] attention mask tensor
@@ -3531,6 +3619,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_JANUS_PRO:
+        case PROJECTOR_TYPE_YOUTUVL:
             return ctx->model.mm_1_b->ne[0];
         case PROJECTOR_TYPE_QWEN3VL:
             // main path + deepstack paths
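
The YoutuVL preprocessing added above picks a resize target by stepping the scale down until the patch budget fits; an equivalent standalone Python sketch (toy parameter values, not the clip.cpp implementation) is:

    import math

    def youtuvl_target_size(width: int, height: int,
                            patch_size: int = 16, merge_size: int = 2,
                            max_num_patches: int = 256) -> tuple[int, int]:
        align = patch_size * merge_size                      # 32: sides must merge 2x2 cleanly

        def aligned(scale: float, size: int) -> int:
            # round up to the nearest multiple of align, never below one merge block
            return max(align, math.ceil(size * scale / align) * align)

        scale = 1.0
        while scale > 0.0:
            w, h = aligned(scale, width), aligned(scale, height)
            if (w // patch_size) * (h // patch_size) <= max_num_patches:
                return w, h
            scale -= 0.02                                    # same 0.02 step as the C++ loop
        return align, align

    print(youtuvl_target_size(1024, 768))   # downscaled, aligned size within the patch budget
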
diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h
index e08c33f353a5944a07f91a87ce19900dfb74e4f5..74e94f60ec0c49fc29906883e4ed47bc5814e402 100644 (file)
@@ -27,6 +27,11 @@ struct clip_graph_qwen3vl : clip_graph {
     ggml_cgraph * build() override;
 };
 
+struct clip_graph_youtuvl : clip_graph {
+    clip_graph_youtuvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+    ggml_cgraph * build() override;
+};
+
 struct clip_graph_minicpmv : clip_graph {
     clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
     ggml_cgraph * build() override;
diff --git a/tools/mtmd/models/youtuvl.cpp b/tools/mtmd/models/youtuvl.cpp
new file mode 100644 (file)
index 0000000..ffbf2be
--- /dev/null
@@ -0,0 +1,179 @@
+#include "models.h"
+
+ggml_cgraph * clip_graph_youtuvl::build() {
+    GGML_ASSERT(model.class_embedding == nullptr);
+    const int batch_size       = 1;
+    const bool use_window_attn = !hparams.wa_layer_indexes.empty();
+    const int n_pos            = n_patches;
+    const int num_position_ids = n_pos * 4;
+    const int m = 2;
+    const int Wp = n_patches_x;
+    const int Hp = n_patches_y;
+    const int Hm = Hp / m;
+    const int Wm = Wp / m;
+    norm_type norm_t = NORM_TYPE_NORMAL;
+
+    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+
+    ggml_tensor * inp = build_inp_raw();
+
+    // change conv3d to linear
+    // reshape and permute to get patches, permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm)
+    {
+        inp = ggml_reshape_4d(
+            ctx0, inp,
+            Wm * m * patch_size, m * patch_size, Hm, 3);
+        inp = ggml_permute(ctx0, inp, 1, 2, 3, 0);
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            m * patch_size * 3, Wm, m * patch_size, Hm);
+
+        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            m * patch_size * 3, patch_size, m, Hm * Wm);
+
+        inp = ggml_permute(ctx0, inp, 1, 0, 2, 3);
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            patch_size, 3, patch_size, Hm * Wm * m * m);
+
+        inp = ggml_permute(ctx0, inp, 2, 0, 1, 3);
+        inp = ggml_cont_3d(
+            ctx0, inp,
+            3*patch_size* patch_size,  Hm * Wm * m * m, 1);
+    }
+    inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+
+    if (model.patch_bias) {
+        inp = ggml_add(ctx0, inp, model.patch_bias);
+    }
+
+    inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
+
+    ggml_tensor * inpL           = inp;
+    ggml_tensor * window_mask    = nullptr;
+    ggml_tensor * window_idx     = nullptr;
+    ggml_tensor * inv_window_idx = nullptr;
+
+    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    // pre-layernorm
+    if (model.pre_ln_w) {
+        inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+    }
+    if (use_window_attn) {
+        inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
+        ggml_set_name(inv_window_idx, "inv_window_idx");
+        ggml_set_input(inv_window_idx);
+        // mask for window attention
+        window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
+        ggml_set_name(window_mask, "window_mask");
+        ggml_set_input(window_mask);
+
+        // if flash attn is used, we need to pad the mask and cast to f16
+        if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+            window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
+        }
+
+        // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
+        GGML_ASSERT(batch_size == 1);
+        inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
+        inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
+        inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
+    }
+
+    // loop over layers
+    for (int il = 0; il < n_layer; il++) {
+        const auto & layer = model.layers[il];
+        const bool full_attn = use_window_attn ? hparams.wa_layer_indexes.count(il) > 0 : true;
+
+        ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
+
+        // layernorm1
+        cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+        // self-attention
+        {
+            ggml_tensor * Qcur = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
+            ggml_tensor * Kcur = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
+            ggml_tensor * Vcur = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);
+
+            Qcur = ggml_rope_multi(
+                ctx0, Qcur, positions, nullptr,
+                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+            Kcur = ggml_rope_multi(
+                ctx0, Kcur, positions, nullptr,
+                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+
+            ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;
+
+            cur = build_attn(layer.o_w, layer.o_b,
+                Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
+        }
+        // re-add the layer input, e.g., residual
+        cur = ggml_add(ctx0, cur, inpL);
+
+        inpL = cur; // inpL = residual, cur = hidden_states
+
+        // layernorm2
+        cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+
+        // ffn
+        cur = build_ffn(cur,
+            layer.ff_up_w, layer.ff_up_b,
+            nullptr, nullptr,
+            layer.ff_down_w, layer.ff_down_b,
+            hparams.ffn_op, il);
+
+        // residual 2
+        cur = ggml_add(ctx0, inpL, cur);
+
+        inpL = cur;
+    }
+
+    ggml_tensor * embeddings = inpL;
+    if (use_window_attn) {
+        const int spatial_merge_unit = 4;
+        window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit);
+        ggml_set_name(window_idx, "window_idx");
+        ggml_set_input(window_idx);
+        GGML_ASSERT(batch_size == 1);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit);
+        embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size);
+        cb(embeddings, "window_order_restored", -1);
+    }
+
+    // post-layernorm (part of Siglip2VisionTransformer, applied after encoder)
+    if (model.post_ln_w) {
+        embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
+    }
+
+    // Now apply merger (VLPatchMerger):
+    // 1. Apply RMS norm (ln_q in VLPatchMerger)
+    embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1);
+    cb(embeddings, "merger_normed", -1);
+
+    // 2. First reshape for spatial merge (merge 2x2 patches)
+    embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
+    cb(embeddings, "merger_reshaped", -1);
+
+    embeddings = build_ffn(embeddings,
+                    model.mm_0_w, model.mm_0_b,
+                    nullptr, nullptr,
+                    model.mm_1_w, model.mm_1_b,
+                    FFN_GELU,
+                    -1);
+    ggml_build_forward_expand(gf, embeddings);
+
+    return gf;
+}
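
As a reading aid for the graph above, a tiny Python sketch (toy shapes, no ggml, not part of this patch) of the per-layer attention choice and the final 2x2 spatial merge that feeds the projector MLP:

    def layer_uses_full_attn(il: int, wa_layer_indexes: set[int]) -> bool:
        # window attention everywhere except the explicitly listed layers;
        # if the set is empty, every layer uses full attention
        return True if not wa_layer_indexes else il in wa_layer_indexes

    def merge_2x2(n_embd: int, n_patches: int) -> tuple[int, int]:
        # [n_embd, n_patches] embeddings regrouped so each row holds a 2x2 patch block
        assert n_patches % 4 == 0
        return n_embd * 4, n_patches // 4

    print([layer_uses_full_attn(il, {2, 5, 8, 11}) for il in range(6)])
    # [False, False, True, False, False, True]
    print(merge_2x2(1152, 1024))   # (4608, 256): rows fed to the merger MLP (toy dims)
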
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index b0b5ab42abb59e388a1d2f7f7b08a0bdddd7b057..fca55b76f8cf6c760e6c73df1b3a75e2caf231ac 100644 (file)
@@ -283,7 +283,7 @@ struct mtmd_context {
             // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
             img_end = "[IMG_END]";
 
-        } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) {
+        } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL || proj == PROJECTOR_TYPE_YOUTUVL) {
             // <|vision_start|> ... (image embeddings) ... <|vision_end|>
             img_beg = "<|vision_start|>";
             img_end = "<|vision_end|>";