self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
-@ModelBase.register("Lfm2ForCausalLM")
-@ModelBase.register("LFM2ForCausalLM")
+@ModelBase.register("Lfm2ForCausalLM", "LFM2ForCausalLM")
class LFM2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.LFM2
        self._add_feed_forward_length()
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+        if is_vision_tensor:
+            # skip vision tensors
+            return []
+
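+        # note (assumption from this conversion logic): the LFM2-VL checkpoint stores the
+        # text-model weights under a "language_model." prefix; stripping it lets the
+        # existing LFM2 tensor mapping apply unchanged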
+        name = name.replace("language_model.", "")
+
        # conv op requires 2d tensor
        if 'conv.conv' in name:
            data_torch = data_torch.squeeze(1)
        return [(self.map_tensor_name(name), data_torch)]
+@ModelBase.register("Lfm2VlForConditionalGeneration")
+class LFM2VLModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        # TODO(tarek): image_size is not specified for dynamic resolution, so set it here for compatibility
+        self.hparams_vision["image_size"] = 256
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LFM2)
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.find_vparam(["layer_norm_eps"]))
+        self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("downsample_factor", 2))
+        self.gguf_writer.add_vision_use_gelu(True)
+        # vision_feature_layer uses Python indexing, e.g. -1 selects the last layer -> vision_feature_layers_to_drop = 0
+        vision_feature_layers_to_drop = -(self.global_config.get("vision_feature_layer", -1) + 1)
+        self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys) - vision_feature_layers_to_drop)
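+        # worked example (illustrative values, not read from any config): with a
+        # hypothetical 12-block vision tower and vision_feature_layer == -2,
+        # vision_feature_layers_to_drop = -(-2 + 1) = 1, so 12 - 1 = 11 blocks are
+        # written; the default of -1 keeps every block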
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+
+        if is_vision_tensor:
+            # remove "model." prefix
+            name = name.replace("model.vision_tower.", "vision_tower.")
+            name = name.replace("model.multi_modal_projector.", "multi_modal_projector.")
+
+            if "patch_embedding.weight" in name:
+                data_torch = data_torch.view(data_torch.shape[0], 16, 16, 3).permute(0, 3, 1, 2)
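+                # the patch embedding arrives as a Linear weight over flattened 16x16x3
+                # patches; the view + permute converts it to Conv2d layout (out, 3, 16, 16)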
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return []  # skip other tensors
+
+
@ModelBase.register("SmallThinkerForCausalLM")
class SmallThinkerModel(TextModel):
    model_arch = gguf.MODEL_ARCH.SMALLTHINKER
    // LLaVA projection
    ggml_tensor * mm_input_norm_w = nullptr;
+    ggml_tensor * mm_input_norm_b = nullptr;
    ggml_tensor * mm_0_w = nullptr;
    ggml_tensor * mm_0_b = nullptr;
    ggml_tensor * mm_2_w = nullptr;
    ggml_cgraph * build_siglip() {
        ggml_tensor * inp = build_inp();
+
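+        // LFM2 accepts dynamic input resolutions, so the learned position embeddings
+        // may not match the current patch grid and are interpolated to it
+        // (see resize_position_embeddings() below)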
+        ggml_tensor * learned_pos_embd = model.position_embeddings;
+        if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+            learned_pos_embd = resize_position_embeddings();
+        }
+
        ggml_tensor * cur = build_vit(
                                inp, n_patches,
                                NORM_TYPE_NORMAL,
                                hparams.ffn_op,
-                                model.position_embeddings,
+                                learned_pos_embd,
                                nullptr);
        if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) {
                bsz);
            cur = ggml_mul_mat(ctx0, model.projection, cur);
+        } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+            // pixel unshuffle block: trade spatial resolution for channel depth
+            const int scale_factor = model.hparams.proj_scale_factor;
+            GGML_ASSERT(scale_factor > 1);
+
+            const int n_embd = cur->ne[0];
+            int width = img.nx / patch_size;
+            int height = img.ny / patch_size;
+
+            // pad width and height to multiples of scale_factor
+            const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
+            const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
+            cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
+            if (pad_width || pad_height) {
+                cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
+                width += pad_width;
+                height += pad_height;
+            }
+
+            // unshuffle h
+            cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
+            cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
+
+            // unshuffle w
+            cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
+            cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
+
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
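+            // e.g. (illustrative numbers): with scale_factor = 2, a 32x24 patch grid of
+            // n_embd-dim tokens becomes a 16x12 grid of 4*n_embd-dim tokens after the two
+            // unshuffle steps, i.e. 16*12 = 192 output tokens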
+
+            // projection: LayerNorm -> Linear -> GELU -> Linear
+            cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm eps
+            cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
+            cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
+
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_2_b);
        } else {
            GGML_ABORT("SigLIP: Unsupported projector type");
        }
}
}
+    // SigLIP2 NaFlex-style: resize the learned position embeddings to the current patch grid
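+    // e.g. (illustrative): a 16x16 learned grid (256 positions, matching image_size 256
+    // with patch_size 16) is bilinearly interpolated to 40x24 for a 640x384 input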
+    ggml_tensor * resize_position_embeddings() {
+        ggml_tensor * pos_embd = model.position_embeddings;
+        const int height = img.ny / patch_size;
+        const int width = img.nx / patch_size;
+
+        if (!pos_embd || height * width == pos_embd->ne[1]) {
+            return pos_embd;
+        }
+
+        const int n_pos_embd = std::sqrt(pos_embd->ne[1]);
+        pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_pos_embd, n_pos_embd); // -> (n_embd, n_pos_embd, n_pos_embd)
+        pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_pos_embd, n_pos_embd, n_embd)
+        pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, GGML_SCALE_MODE_BILINEAR); // -> (width, height, n_embd)
+        pos_embd = ggml_reshape_2d(ctx0, pos_embd, height * width, n_embd); // -> (height * width, n_embd)
+        pos_embd = ggml_transpose(ctx0, pos_embd); // -> (n_embd, height * width)
+        pos_embd = ggml_cont(ctx0, pos_embd);
+
+        return pos_embd;
+    }
+
    // build vision transformer (ViT) cgraph
    // this function should cover most of the models
    // if your model has specific features, you should probably duplicate this function
    switch (ctx->proj_type()) {
        case PROJECTOR_TYPE_GEMMA3:
        case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_LFM2:
            {
                res = graph.build_siglip();
            } break;
}
                } break;
            case PROJECTOR_TYPE_IDEFICS3:
+            case PROJECTOR_TYPE_LFM2:
            case PROJECTOR_TYPE_INTERNVL:
                {
                    get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
                {
                    model.projection = get_tensor(TN_MM_PROJECTOR);
                } break;
+            case PROJECTOR_TYPE_LFM2:
+                {
+                    model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
+                    model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                } break;
            case PROJECTOR_TYPE_PIXTRAL:
                {
                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
        res_imgs->grid_y = inst.grid_size.height;
        return true;
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+        GGML_ASSERT(params.proj_scale_factor);
+
+        // smart resize: snap the image to a multiple of (patch_size * proj_scale_factor)
+        // while keeping the token count within [min_image_tokens, max_image_tokens]
+        const int width = img->nx;
+        const int height = img->ny;
+        const int total_factor = params.patch_size * params.proj_scale_factor;
+        constexpr int min_image_tokens = 64;
+        constexpr int max_image_tokens = 256;
+        const float min_pixels = min_image_tokens * total_factor * total_factor;
+        const float max_pixels = max_image_tokens * total_factor * total_factor;
+
+        auto round_by_factor = [f = total_factor](float x) { return static_cast<int>(std::nearbyintf(x / static_cast<float>(f))) * f; };
+        auto ceil_by_factor = [f = total_factor](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
+        auto floor_by_factor = [f = total_factor](float x) { return static_cast<int>(std::floor(x / static_cast<float>(f))) * f; };
+
+        int h_bar = std::max(total_factor, round_by_factor(height));
+        int w_bar = std::max(total_factor, round_by_factor(width));
+
+        if (h_bar * w_bar > max_pixels) {
+            const auto beta = std::sqrt((height * width) / max_pixels);
+            h_bar = std::max(total_factor, floor_by_factor(height / beta));
+            w_bar = std::max(total_factor, floor_by_factor(width / beta));
+        } else if (h_bar * w_bar < min_pixels) {
+            const auto beta = std::sqrt(min_pixels / (height * width));
+            h_bar = ceil_by_factor(height * beta);
+            w_bar = ceil_by_factor(width * beta);
+        }
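+        // worked example (illustrative): patch_size 16 and proj_scale_factor 2 give
+        // total_factor 32; a 1000x600 image first rounds to 992x608 (31*19 = 589 tokens > 256),
+        // then beta = sqrt(600000 / 262144) ~= 1.51 shrinks it to 640x384, i.e.
+        // (640/32)*(384/32) = 240 tokens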
+
+        const std::array<uint8_t, 3> pad_color = {122, 116, 104};
+
+        clip_image_u8 resized_img;
+        image_manipulation::resize_and_pad_image(*img, resized_img, clip_image_size{w_bar, h_bar}, pad_color);
+        clip_image_f32_ptr res(clip_image_f32_init());
+        normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
+        res_imgs->entries.push_back(std::move(res));
+        return true;
    }
    // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
                    n_patches_sq /= 2;
                }
            } break;
+        case PROJECTOR_TYPE_LFM2:
+            {
+                n_patches_sq = (img->nx / (params.patch_size * params.proj_scale_factor)) * (img->ny / (params.patch_size * params.proj_scale_factor));
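+                // e.g. a 640x384 preprocessed image with patch_size 16 and
+                // proj_scale_factor 2 yields (640/32)*(384/32) = 240 tokens,
+                // matching the pixel-unshuffled grid produced in build_siglip()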
+            } break;
        default:
            GGML_ABORT("unsupported projector type");
    }
        case PROJECTOR_TYPE_INTERNVL:
        case PROJECTOR_TYPE_QWEN2A:
        case PROJECTOR_TYPE_ULTRAVOX:
+        case PROJECTOR_TYPE_LFM2:
        case PROJECTOR_TYPE_VOXTRAL:
            {
                // do nothing
            return ctx->model.mm_model_proj->ne[1];
        case PROJECTOR_TYPE_QWEN2A:
            return ctx->model.mm_fc_w->ne[1];
+        case PROJECTOR_TYPE_LFM2:
+            return ctx->model.mm_2_w->ne[1];
        default:
            GGML_ABORT("Unknown projector type");
    }