raise ValueError(f"Unprocessed experts: {experts}")
-@ModelBase.register("LlavaForConditionalGeneration")
+@ModelBase.register(
+ "LlavaForConditionalGeneration", # pixtral
+ "Mistral3ForConditionalGeneration", # mistral small 3.1
+)
class LlavaVisionModel(VisionModel):
img_break_tok_id = -1
if self.hparams["model_type"] == "pixtral":
# layer_norm_eps is not in config.json; it is hard-coded in modeling_pixtral.py
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
- self.img_break_tok_id = 12 # see tokenizer_config.json
+ self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
+ logger.info(f"Image break token id: {self.img_break_tok_id}")
else:
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
+ def get_token_id(self, token: str) -> int:
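+ # resolve a special token id from tokenizer_config.json; its added_tokens_decoder
+ # maps string ids to token data, e.g. {"12": {"content": "[IMG_BREAK]", ...}}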
+ tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+ with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+ added_tokens_decoder = json.load(f)['added_tokens_decoder']
+ for id_, token_data in added_tokens_decoder.items():
+ if token_data["content"] == token:
+ return int(id_)
+ raise ValueError(f"Token '{token}' not found in tokenizer config.")
+
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if hparams["model_type"] == "pixtral":
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
- self.gguf_writer.add_vision_use_silu(True)
+
+ # hidden_act
+ if hparams["hidden_act"] == "silu":
+ self.gguf_writer.add_vision_use_silu(True)
+ elif hparams["hidden_act"] == "gelu":
+ self.gguf_writer.add_vision_use_gelu(True)
+ else:
+ raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")
+
+ # spatial_merge_size
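+ # mistral small 3.1 stores this in the top-level config (not in vision_config),
+ # hence the lookup via global_config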
+ if "spatial_merge_size" in self.global_config:
+ self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
std::unordered_set<int32_t> vision_feature_layer;
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
+ int32_t spatial_merge_size = 0; // mistral small 3.1 patch merger (0 = disabled)
};
struct clip_layer {
struct ggml_tensor * projection;
// LLaVA projection
+ struct ggml_tensor * mm_input_norm_w = nullptr; // mistral small 3.1
struct ggml_tensor * mm_0_w = nullptr;
struct ggml_tensor * mm_0_b = nullptr;
struct ggml_tensor * mm_2_w = nullptr;
// pixtral
struct ggml_tensor * token_embd_img_break = nullptr;
+ struct ggml_tensor * mm_patch_merger_w = nullptr;
};
struct clip_ctx {
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
+ const int n_merge = hparams.spatial_merge_size; // 0 when the model has no patch merger
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
{
ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
- gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
+ if (ctx->use_silu) {
+ gate_proj = ggml_silu(ctx0, gate_proj);
+ } else if (ctx->use_gelu) {
+ gate_proj = ggml_gelu(ctx0, gate_proj);
+ } else {
+ GGML_ABORT("Pixtral: Unsupported activation");
+ }
cur = ggml_mul(ctx0, up_proj, gate_proj);
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
}
embeddings = cur;
}
- // LlavaMultiModalProjector (with GELU activation)
+ // mistral small 3.1 patch merger
+ // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
+ if (model.mm_patch_merger_w) {
+ GGML_ASSERT(hparams.spatial_merge_size > 0);
+
+ ggml_tensor * cur = embeddings;
+ cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
+
+ // reshape image tokens to 2D grid
+ cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
+ cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
+ cur = ggml_cont(ctx0, cur);
+
+ // torch.nn.functional.unfold is just an im2col under the hood
+ // we just need a dummy kernel to make it work: only the kernel's shape is read, never its values
+ ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
+ cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
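+ // result: [hidden_size * n_merge * n_merge, n_patches_x / n_merge, n_patches_y / n_merge]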
+
+ // project from hidden_size * n_merge^2 back to hidden_size
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+ cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
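+ // cur is now [hidden_size, (n_patches_x / n_merge) * (n_patches_y / n_merge)]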
+ embeddings = cur;
+ }
+
+ // LlavaMultiModalProjector (always uses GELU activation)
{
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
- embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+ if (model.mm_1_b) {
+ embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+ }
embeddings = ggml_gelu(ctx0, embeddings);
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
- embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+ if (model.mm_2_b) {
+ embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+ }
}
// arrangement of the [IMG_BREAK] token
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
// after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
+ const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
+ const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
+ const int p_total = p_x * p_y;
const int n_embd_text = embeddings->ne[0];
- const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
+ const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
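+ // e.g. mistral small 3.1 (1540x1540 image, patch_size 14, n_merge 2): p_x = p_y = 55, so 55*55 + 54 = 3079 tokens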
- ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
- ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
+ ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, p_x, p_y);
+ ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, p_y);
tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
tok = ggml_add(ctx0, tok, model.token_embd_img_break);
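+ // the add broadcasts the single [IMG_BREAK] embedding across all p_y rows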
cur = ggml_concat(ctx0, cur, tok, 1);
case PROJECTOR_TYPE_PIXTRAL:
{
hparams.rope_theta = 10000.0f;
+ get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false); // optional, only present for mistral small 3.1
} break;
case PROJECTOR_TYPE_QWEN25VL:
{
case PROJECTOR_TYPE_PIXTRAL:
{
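+ // the projector bias tensors are optional here: the mistral small 3.1 projector has no bias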
vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
- vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+ vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
- vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+ vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
// [IMG_BREAK] token embedding
vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+ // for mistral small 3.1
+ vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
+ vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
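+ // when these are absent (plain pixtral) the getters return nullptr,
+ // which disables the patch merger branch in the graph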
} break;
default:
GGML_ASSERT(false && "unknown projector type");
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
- int n_patches_x = img->nx / params.patch_size;
- int n_patches_y = img->ny / params.patch_size;
+ int n_merge = ctx->vision_model.hparams.spatial_merge_size;
+ int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
+ int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
}
return ctx->vision_model.mm_model_peg_0_b->ne[0];
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_PIXTRAL:
- return ctx->vision_model.mm_2_b->ne[0];
+ return ctx->vision_model.mm_2_w->ne[1]; // read the output dim from the weight, since the bias is now optional
case PROJECTOR_TYPE_MLP_NORM:
return ctx->vision_model.mm_3_b->ne[0];
case PROJECTOR_TYPE_MINICPMV:
BLOCK_COUNT = "clip.vision.block_count"
IMAGE_MEAN = "clip.vision.image_mean"
IMAGE_STD = "clip.vision.image_std"
+ SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size"
USE_GELU = "clip.use_gelu"
USE_SILU = "clip.use_silu"
V_ENC_FFN_DOWN = auto()
V_PRE_NORM = auto()
V_POST_NORM = auto()
+ V_MM_INP_NORM = auto()
V_MM_INP_PROJ = auto() # gemma3
V_MM_SOFT_EMB_NORM = auto() # gemma3
V_RESMPL_POS_EMBD_K = auto() # minicpmv
V_RESMPL_PROJ = auto() # minicpmv
V_RESMPL_QUERY = auto() # minicpmv
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
+ V_MM_PATCH_MERGER = auto() # mistral small 3.1
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
+ MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
+ MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_PRE_NORM,
MODEL_TENSOR.V_POST_NORM,
MODEL_TENSOR.V_MM_INP_PROJ,
+ MODEL_TENSOR.V_MM_INP_NORM,
MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
MODEL_TENSOR.V_RESMPL_ATTN_Q,
MODEL_TENSOR.V_RESMPL_PROJ,
MODEL_TENSOR.V_RESMPL_QUERY,
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
+ MODEL_TENSOR.V_MM_PATCH_MERGER,
],
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD,