if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665":
# ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer
res = "kormo"
+ if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1":
+ # ref: https://huggingface.co/tencent/Youtu-LLM-2B
+ res = "youtu"
if chkhsh == "16389f0a1f51ee53e562ffd51c371dc508639ab0e4261502071836e50e223e91":
# ref: https://huggingface.co/upstage/Solar-Open-100B
res = "solar-open"
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
+ "YoutuForCausalLM",
)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
super().set_gguf_parameters()
hparams = self.hparams
- self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+ # first_k_dense_replace: number of leading layers using dense FFN instead of MoE
+ # For non-MoE models (like Youtu), set to n_layer to use dense FFN for all layers
+ # For MoE models (like DeepSeek-V2), this is the number of leading non-MoE layers
+ has_moe = hparams.get("n_routed_experts") is not None
+ first_k_dense_replace = hparams.get("first_k_dense_replace")
+ if first_k_dense_replace is None:
+ # Default: if no MoE, all layers are dense; if MoE, none are dense
+ first_k_dense_replace = hparams["num_hidden_layers"] if not has_moe else 0
+ self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
- self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
- self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
- self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
- self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
- self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+ # MoE parameters (required by C++ code for DEEPSEEK2 arch)
+ # For non-MoE models like Youtu, use intermediate_size as expert_feed_forward_length
+ moe_intermediate_size = self.find_hparam(["moe_intermediate_size", "intermediate_size"], optional=False)
+ self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+
+ if (n_routed_experts := hparams.get("n_routed_experts")) is not None:
+ self.gguf_writer.add_expert_count(n_routed_experts)
+
+ # expert_shared_count is required by C++ code, default to 0 for non-MoE models
+ n_shared_experts = hparams.get("n_shared_experts", 0)
+ self.gguf_writer.add_expert_shared_count(n_shared_experts)
+
+ # When not set, C++ code will use scale_w = false to skip the no-op scaling
+ if (routed_scaling_factor := hparams.get("routed_scaling_factor")) is not None:
+ self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
+
+ if (norm_topk_prob := hparams.get("norm_topk_prob")) is not None and norm_topk_prob:
+ self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
# skip vision tensors and remove "language_model." for Kimi-VL
if "vision_tower" in name or "multi_modal_projector" in name:
return []
-
+ if name.startswith("siglip2.") or name.startswith("merger."):
+ return []
if name.startswith("language_model."):
name = name.replace("language_model.", "")
+ # skip lm_head.weight if tie_word_embeddings is True
+ if self.hparams.get("tie_word_embeddings", False):
+ if name == "lm_head.weight" or name == "model.lm_head.weight":
+ logger.info("Skipping tied output layer 'lm_head.weight' (will use token_embd.weight)")
+ return []
+
# rename e_score_correction_bias tensors
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
return []
+@ModelBase.register("YOUTUVLForConditionalGeneration", "YOUTUVLForCausalLM")
+class YOUTUVLVisionModel(MmprojModel):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ assert self.hparams_vision is not None
+ self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560)
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+
+ self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.YOUTUVL)
+ self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
+
+ # Handle activation function
+ hidden_act = str(self.hparams.get("hidden_act", "gelu_pytorch_tanh")).lower()
+ if hidden_act in ("gelu", "gelu_pytorch_tanh", "gelu_fast", "gelu_new", "gelu_accurate"):
+ self.gguf_writer.add_vision_use_gelu(True)
+ elif hidden_act == "silu":
+ self.gguf_writer.add_vision_use_silu(True)
+ else:
+ raise ValueError(f"Unsupported activation function for YOUTUVL: {hidden_act}")
+
+ self.gguf_writer.add_vision_spatial_merge_size(self.hparams.get("spatial_merge_size", 2))
+
+ window_size = self.hparams.get("window_size")
+ if window_size is not None:
+ self.gguf_writer.add_vision_window_size(window_size)
+ # fullatt_block_indexes contains explicit layer indices that use full attention
+ # e.g., [2, 5, 8, 11] means layers 2, 5, 8, 11 use full attention
+ # All other layers use window attention
+ fullatt_block_indexes = self.hparams.get("fullatt_block_indexes")
+ assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for youtuvl"
+ # Store the explicit layer indices for YoutuVL (irregular pattern approach)
+ self.gguf_writer.add_vision_wa_layer_indexes(layers=fullatt_block_indexes)
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ del bid # unused
+
+ # Skip language model tensors
+ skip_prefixes = ('lm_head.', 'model.layers.', 'model.embed_tokens.', 'model.norm.')
+ if name.startswith(skip_prefixes):
+ return []
+
+ # Try to map the tensor using TensorNameMap (handles vision encoder and projector)
+ try:
+ new_name = self.map_tensor_name(name)
+ return [(new_name, data_torch)]
+ except ValueError:
+ # If mapping fails, log warning and skip
+ logger.warning(f"Cannot map tensor: {name}")
+ return []
+
+
@ModelBase.register("SolarOpenForCausalLM")
class SolarOpenModel(Glm4MoeModel):
model_arch = gguf.MODEL_ARCH.GLM4_MOE
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
{"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", },
+ {"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
]
USE_GELU = "clip.use_gelu"
USE_SILU = "clip.use_silu"
N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
+ WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl
IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"
+ WINDOW_SIZE = "clip.vision.window_size"
class Attention:
HEAD_COUNT = "clip.vision.attention.head_count"
LFM2A = "lfm2a" # audio
MUSIC_FLAMINGO = "musicflamingo" # audio
GLM4V = "glm4v"
+ YOUTUVL = "youtuvl"
# Items here are (block size, type size)
self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
def add_vision_n_wa_pattern(self, value: int) -> None:
+ """Add window attention pattern interval for vision models.
+
+ This defines the pattern interval for window attention vs full attention layers.
+ For example, if n_wa_pattern=4, then layers 3, 7, 11, ... use full attention,
+ while other layers use window attention.
+
+ Used by models like Qwen2.5-VL where full attention layers follow a regular pattern.
+ """
self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
+ def add_vision_wa_layer_indexes(self, layers: Sequence[int]) -> None:
+ """Add explicit layer indexes that use full attention in vision models.
+
+ This specifies the exact layer indices (0-based) that should use full attention
+ instead of window attention. All other layers will use window attention.
+
+ Args:
+ layers: List of layer indices that use full attention (e.g., [3, 7, 11, 15])
+
+ Used by models like YoutuVL where full attention layers are explicitly specified
+ rather than following a regular pattern.
+
+ Difference from add_vision_n_wa_pattern:
+ - n_wa_pattern: Defines a regular interval pattern (every Nth layer uses full attention)
+ - wa_layer_indexes: Explicitly lists which layers use full attention (irregular pattern)
+ """
+ self.add_array(Keys.ClipVision.WA_LAYER_INDEXES, layers)
+
def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
+ def add_vision_window_size(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value)
+
# audio models
def add_audio_projection_dim(self, value: int) -> None:
MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}",
"visual.merger.mlp.{bid}", # qwen2vl
+ "merger.mlp.{bid}",
),
MODEL_TENSOR.V_MMPROJ_FC: (
"visual.patch_embed.proj", # qwen2vl
"vision_tower.patch_embed.proj", # kimi-vl
"model.vision.patch_embedding.proj", # cogvlm
+ "siglip2.vision_model.embeddings.patch_embedding",
),
MODEL_TENSOR.V_ENC_EMBD_NORM: (
"vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
+ "siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl
),
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
"vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
+ "siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj",
),
MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
"vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
+ "siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj",
),
MODEL_TENSOR.V_ENC_INPUT_NORM: (
"visual.blocks.{bid}.norm1", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
+ "siglip2.vision_model.encoder.layers.{bid}.layer_norm1",
),
MODEL_TENSOR.V_ENC_ATTN_O: (
"visual.blocks.{bid}.attn.proj", # qwen2vl
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
"model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
+ "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl
),
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
"visual.blocks.{bid}.norm2", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
+ "siglip2.vision_model.encoder.layers.{bid}.layer_norm2",
),
MODEL_TENSOR.V_ENC_FFN_UP: (
"visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
+ "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1",
),
MODEL_TENSOR.V_ENC_FFN_GATE: (
"visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
+ "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2",
),
MODEL_TENSOR.V_LAYER_SCALE_1: (
"visual.merger.ln_q", # qwen2vl
"vision_tower.encoder.final_layernorm", # kimi-vl
"visual.post_layernorm", # glm4v
+ "siglip2.vision_model.post_layernorm",
),
MODEL_TENSOR.V_MM_POST_NORM: (
"multi_modal_projector.pre_norm",
"pre_mm_projector_norm",
"model.vision.linear_proj.norm1", # cogvlm
+ "merger.ln_q",
),
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
- ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
- output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+ // try to load output.weight, if not found, use token_embd (tied embeddings)
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ if (!output) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
- output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+ // try to load output.weight, if not found, use token_embd (tied embeddings)
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ if (!output) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
};
break;
+ case LLAMA_VOCAB_PRE_TYPE_YOUTU:
+ regex_exprs = {
+ "[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥-ゟ゠-ヿ]+",
+ "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+ };
+ break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
regex_exprs = {
"[\r\n]",
tokenizer_pre == "deepseek-v3") {
pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
clean_spaces = false;
+ } else if (
+ tokenizer_pre == "youtu") {
+ pre_type = LLAMA_VOCAB_PRE_TYPE_YOUTU;
+ clean_spaces = false;
+ ignore_merges = true;
} else if (
tokenizer_pre == "falcon") {
pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON;
LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41,
LLAMA_VOCAB_PRE_TYPE_AFMOE = 42,
LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43,
+ LLAMA_VOCAB_PRE_TYPE_YOUTU = 44,
};
struct LLM_KV;
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
- true, hparams.expert_weights_scale,
+ hparams.expert_weights_scale, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(moe_out, "ffn_moe_out", il);
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
{ "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
{ "\\p{S}", unicode_cpt_flags::SYMBOL },
+ { "\\p{Lu}", unicode_cpt_flags::LETTER }, // Uppercase letter
+ { "\\p{Ll}", unicode_cpt_flags::LETTER }, // Lowercase letter
+ { "\\p{Lt}", unicode_cpt_flags::LETTER }, // Titlecase letter
+ { "\\p{Lm}", unicode_cpt_flags::LETTER }, // Modifier letter
+ { "\\p{Lo}", unicode_cpt_flags::LETTER }, // Other letter
};
static const std::map<int, int> k_ucat_cpt = {
continue;
}
- if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
+ // Match \p{...} Unicode properties of varying lengths
+ if (regex_expr[i + 0] == '\\' && i + 3 < regex_expr.size() &&
regex_expr[i + 1] == 'p' &&
- regex_expr[i + 2] == '{' &&
- regex_expr[i + 4] == '}') {
- const std::string pat = regex_expr.substr(i, 5);
- if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
- if (!inside) {
- regex_expr_collapsed += '[';
+ regex_expr[i + 2] == '{') {
+ // Find the closing brace
+ size_t closing_brace = regex_expr.find('}', i + 3);
+ if (closing_brace != std::string::npos && closing_brace <= i + 10) { // reasonable limit
+ const std::string pat = regex_expr.substr(i, closing_brace - i + 1);
+ if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
+ if (!inside) {
+ regex_expr_collapsed += '[';
+ }
+ regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
+ regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
+ if (!inside) {
+ regex_expr_collapsed += ']';
+ }
+ i = closing_brace;
+ continue;
}
- regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
- regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
- if (!inside) {
- regex_expr_collapsed += ']';
- }
- i += 4;
- continue;
}
}
models/qwen3vl.cpp
models/siglip.cpp
models/whisper-enc.cpp
+ models/youtuvl.cpp
)
set_target_properties(mtmd PROPERTIES
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
-#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
-#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
-#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
-#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
-#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
-#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
-#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num"
+#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
+#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
+#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
+#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
+#define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes"
+#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
+#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
+#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num"
// audio-specific
#define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities
PROJECTOR_TYPE_JANUS_PRO,
PROJECTOR_TYPE_LFM2A,
PROJECTOR_TYPE_GLM4V,
+ PROJECTOR_TYPE_YOUTUVL,
PROJECTOR_TYPE_UNKNOWN,
};
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
+ { PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {
std::unordered_set<int32_t> vision_feature_layer;
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
+ std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
// audio
int32_t n_mel_bins = 0; // whisper preprocessor
{
builder = std::make_unique<clip_graph_glm4v>(ctx, img);
} break;
+ case PROJECTOR_TYPE_YOUTUVL:
+ {
+ builder = std::make_unique<clip_graph_youtuvl>(ctx, img);
+ } break;
default:
GGML_ABORT("missing cgraph builder");
}
LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__);
}
} break;
+ case PROJECTOR_TYPE_YOUTUVL:
+ {
+ hparams.n_merge = 2;
+ get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+ get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true);
+ std::vector<int> wa_layer_indexes_vec;
+ get_arr_int(KEY_WIN_ATTN_LAYER_INDEXES, wa_layer_indexes_vec, true);
+ for (auto & layer : wa_layer_indexes_vec) {
+ hparams.wa_layer_indexes.insert(layer);
+ }
+ // support max_height * max_width = 8000 * 8000. 8000/16/2 = 250 image tokens
+ hparams.set_limit_image_tokens(1, 62500);
+ hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup
+ } break;
case PROJECTOR_TYPE_GLM4V:
{
hparams.rope_theta = 10000.0f;
LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version);
LOG_INF("%s: n_merge: %d\n", __func__, hparams.n_merge);
- LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
+ LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
+ if (!hparams.wa_layer_indexes.empty()) {
+ LOG_INF("%s: wa_layer_indexes: ", __func__);
+ for (auto & layer : hparams.wa_layer_indexes) {
+ LOG_INF("%d ", layer);
+ }
+ LOG_INF("\n");
+ }
if (hparams.image_min_pixels > 0) {
LOG_INF("%s: image_min_pixels: %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : "");
}
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
} break;
+ case PROJECTOR_TYPE_YOUTUVL:
+ {
+ model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); // merger.ln_q (RMS norm)
+ model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); // merger.mlp.0
+ model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
+ model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2
+ model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+ } break;
case PROJECTOR_TYPE_GLM4V:
{
model.projection = get_tensor(TN_MM_PROJECTOR);
// res_imgs->data[0] = *res;
res_imgs->entries.push_back(std::move(img_f32));
} break;
+ case PROJECTOR_TYPE_YOUTUVL:
+ {
+ const int patch_size = params.patch_size; // typically 16
+ const int merge_size = params.n_merge; // typically 2
+ const int align_size = patch_size * merge_size; // 32
+
+ const int max_num_patches = params.image_max_pixels > 0 ?
+ params.image_max_pixels / (patch_size * patch_size) : 256;
+
+ // Linear search for optimal scale to fit within max_num_patches
+ float scale = 1.0f;
+ int target_height = original_size.height;
+ int target_width = original_size.width;
+
+ auto get_scaled_image_size = [align_size](float scale, int size) -> int {
+ float scaled_size = size * scale;
+ // Round up to nearest multiple of align_size
+ int aligned = static_cast<int>(std::ceil(scaled_size / align_size)) * align_size;
+ // Ensure at least one patch
+ return std::max(align_size, aligned);
+ };
+
+ // Linear search with 0.02 step size
+ while (scale > 0.0f) {
+ target_height = get_scaled_image_size(scale, original_size.height);
+ target_width = get_scaled_image_size(scale, original_size.width);
+
+ int num_patches_h = target_height / patch_size;
+ int num_patches_w = target_width / patch_size;
+ int num_patches = num_patches_h * num_patches_w;
+
+ if (num_patches > max_num_patches) {
+ scale -= 0.02f;
+ } else {
+ break;
+ }
+ }
+
+ clip_image_size new_size = {target_width, target_height};
+
+ // Resize the image
+ clip_image_u8 resized;
+ img_tool::resize(*img, resized, new_size, img_tool::RESIZE_ALGO_BILINEAR, false);
+
+ // Normalize to float32
+ clip_image_f32_ptr img_f32(clip_image_f32_init());
+ normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std);
+
+ // Add to results
+ res_imgs->entries.push_back(std::move(img_f32));
+ } break;
case PROJECTOR_TYPE_IDEFICS3:
{
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V:
+ case PROJECTOR_TYPE_YOUTUVL:
return (img->nx / params.patch_size) / 2;
default:
break;
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V:
+ case PROJECTOR_TYPE_YOUTUVL:
return (img->ny / params.patch_size) / 2;
default:
break;
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL:
case PROJECTOR_TYPE_GLM4V:
+ case PROJECTOR_TYPE_YOUTUVL:
{
// dynamic size (2 conv, so double patch size)
int x_patch = img->nx / (params.patch_size * 2);
const int pos_w = image_size_width / patch_size;
const int pos_h = image_size_height / patch_size;
- const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
auto get_inp_tensor = [&gf](const char * name) {
ggml_tensor * inp = ggml_graph_get_tensor(gf, name);
set_input_i32("positions", positions);
} break;
case PROJECTOR_TYPE_QWEN25VL:
+ case PROJECTOR_TYPE_YOUTUVL:
{
// pw * ph = number of tokens output by ViT after apply patch merger
// ipw * ipw = number of vision token been processed inside ViT
+ const bool use_window_attn = ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL ? hparams.n_wa_pattern > 0 : !hparams.wa_layer_indexes.empty();
const int merge_ratio = 2;
const int pw = image_size_width / patch_size / merge_ratio;
const int ph = image_size_height / patch_size / merge_ratio;
std::vector<int> inv_idx(ph * pw);
if (use_window_attn) {
- const int attn_window_size = 112;
+ const int attn_window_size = hparams.attn_window_size > 0 ? hparams.attn_window_size : 112;
const int grid_window = attn_window_size / patch_size / merge_ratio;
int dst = 0;
// [num_vision_tokens, num_vision_tokens] attention mask tensor
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_JANUS_PRO:
+ case PROJECTOR_TYPE_YOUTUVL:
return ctx->model.mm_1_b->ne[0];
case PROJECTOR_TYPE_QWEN3VL:
// main path + deepstack paths
ggml_cgraph * build() override;
};
+struct clip_graph_youtuvl : clip_graph {
+ clip_graph_youtuvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+ ggml_cgraph * build() override;
+};
+
struct clip_graph_minicpmv : clip_graph {
clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
--- /dev/null
+#include "models.h"
+
+ggml_cgraph * clip_graph_youtuvl::build() {
+ GGML_ASSERT(model.class_embedding == nullptr);
+ const int batch_size = 1;
+ const bool use_window_attn = !hparams.wa_layer_indexes.empty();
+ const int n_pos = n_patches;
+ const int num_position_ids = n_pos * 4;
+ const int m = 2;
+ const int Wp = n_patches_x;
+ const int Hp = n_patches_y;
+ const int Hm = Hp / m;
+ const int Wm = Wp / m;
+ norm_type norm_t = NORM_TYPE_NORMAL;
+
+ int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+
+ ggml_tensor * inp = build_inp_raw();
+
+ // change conv3d to linear
+ // reshape and permute to get patches, permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm)
+ {
+ inp = ggml_reshape_4d(
+ ctx0, inp,
+ Wm * m * patch_size, m * patch_size, Hm, 3);
+ inp = ggml_permute(ctx0, inp, 1, 2, 3, 0);
+ inp = ggml_cont_4d(
+ ctx0, inp,
+ m * patch_size * 3, Wm, m * patch_size, Hm);
+
+ inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
+ inp = ggml_cont_4d(
+ ctx0, inp,
+ m * patch_size * 3, patch_size, m, Hm * Wm);
+
+ inp = ggml_permute(ctx0, inp, 1, 0, 2, 3);
+ inp = ggml_cont_4d(
+ ctx0, inp,
+ patch_size, 3, patch_size, Hm * Wm * m * m);
+
+ inp = ggml_permute(ctx0, inp, 2, 0, 1, 3);
+ inp = ggml_cont_3d(
+ ctx0, inp,
+ 3*patch_size* patch_size, Hm * Wm * m * m, 1);
+ }
+ inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+
+ if (model.patch_bias) {
+ inp = ggml_add(ctx0, inp, model.patch_bias);
+ }
+
+ inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
+
+ ggml_tensor * inpL = inp;
+ ggml_tensor * window_mask = nullptr;
+ ggml_tensor * window_idx = nullptr;
+ ggml_tensor * inv_window_idx = nullptr;
+
+ ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+ ggml_set_name(positions, "positions");
+ ggml_set_input(positions);
+
+ // pre-layernorm
+ if (model.pre_ln_w) {
+ inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+ }
+ if (use_window_attn) {
+ inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
+ ggml_set_name(inv_window_idx, "inv_window_idx");
+ ggml_set_input(inv_window_idx);
+ // mask for window attention
+ window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
+ ggml_set_name(window_mask, "window_mask");
+ ggml_set_input(window_mask);
+
+ // if flash attn is used, we need to pad the mask and cast to f16
+ if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+ window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
+ }
+
+ // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
+ GGML_ASSERT(batch_size == 1);
+ inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
+ inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
+ inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
+ }
+
+ // loop over layers
+ for (int il = 0; il < n_layer; il++) {
+ const auto & layer = model.layers[il];
+ const bool full_attn = use_window_attn ? hparams.wa_layer_indexes.count(il) > 0 : true;
+
+ ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
+
+ // layernorm1
+ cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+ // self-attention
+ {
+ ggml_tensor * Qcur = ggml_add(ctx0,
+ ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
+ ggml_tensor * Kcur = ggml_add(ctx0,
+ ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
+ ggml_tensor * Vcur = ggml_add(ctx0,
+ ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);
+
+ Qcur = ggml_rope_multi(
+ ctx0, Qcur, positions, nullptr,
+ d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+ Kcur = ggml_rope_multi(
+ ctx0, Kcur, positions, nullptr,
+ d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+
+ ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;
+
+ cur = build_attn(layer.o_w, layer.o_b,
+ Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
+ }
+ // re-add the layer input, e.g., residual
+ cur = ggml_add(ctx0, cur, inpL);
+
+ inpL = cur; // inpL = residual, cur = hidden_states
+
+ // layernorm2
+ cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+
+ // ffn
+ cur = build_ffn(cur,
+ layer.ff_up_w, layer.ff_up_b,
+ nullptr, nullptr,
+ layer.ff_down_w, layer.ff_down_b,
+ hparams.ffn_op, il);
+
+ // residual 2
+ cur = ggml_add(ctx0, inpL, cur);
+
+ inpL = cur;
+ }
+
+ ggml_tensor * embeddings = inpL;
+ if (use_window_attn) {
+ const int spatial_merge_unit = 4;
+ window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit);
+ ggml_set_name(window_idx, "window_idx");
+ ggml_set_input(window_idx);
+ GGML_ASSERT(batch_size == 1);
+ embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit);
+ embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+ embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size);
+ cb(embeddings, "window_order_restored", -1);
+ }
+
+ // post-layernorm (part of Siglip2VisionTransformer, applied after encoder)
+ if (model.post_ln_w) {
+ embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
+ }
+
+ // Now apply merger (VLPatchMerger):
+ // 1. Apply RMS norm (ln_q in VLPatchMerger)
+ embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1);
+ cb(embeddings, "merger_normed", -1);
+
+ // 2. First reshape for spatial merge (merge 2x2 patches)
+ embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
+ cb(embeddings, "merger_reshaped", -1);
+
+ embeddings = build_ffn(embeddings,
+ model.mm_0_w, model.mm_0_b,
+ nullptr, nullptr,
+ model.mm_1_w, model.mm_1_b,
+ FFN_GELU,
+ -1);
+ ggml_build_forward_expand(gf, embeddings);
+
+ return gf;
+}
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
img_end = "[IMG_END]";
- } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) {
+ } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL || proj == PROJECTOR_TYPE_YOUTUVL) {
// <|vision_start|> ... (image embeddings) ... <|vision_end|>
img_beg = "<|vision_start|>";
img_end = "<|vision_end|>";