model : support vision LiquidAI LFM2-VL family (#15347)
author    Tarek Dakhran <redacted>
Sat, 16 Aug 2025 21:33:54 +0000 (23:33 +0200)
committer GitHub <redacted>
Sat, 16 Aug 2025 21:33:54 +0000 (23:33 +0200)
* wip lfm2 vision model

* Fix conv weight

* Implement dynamic resolution

* Fix cuda

* support LFM2-VL-450M

* happy CI

* Remove extra `ggml_conv` and put others into the right place

Co-authored-by: Sigbjørn Skjæret <redacted>
---------

Co-authored-by: Xuan Son Nguyen <redacted>
Co-authored-by: Sigbjørn Skjæret <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
tools/mtmd/clip-impl.h
tools/mtmd/clip.cpp

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 444e2cbdfbb6a09a13f4cc74963ef748acb3caa2..bd21e55f4a90ca61b9ba4271a77e99cf37bb7c5a 100755 (executable)
@@ -8251,8 +8251,7 @@ class GptOssModel(TextModel):
         self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
 
 
-@ModelBase.register("Lfm2ForCausalLM")
-@ModelBase.register("LFM2ForCausalLM")
+@ModelBase.register("Lfm2ForCausalLM", "LFM2ForCausalLM")
 class LFM2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.LFM2
 
@@ -8287,6 +8286,13 @@ class LFM2Model(TextModel):
         self._add_feed_forward_length()
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+        if is_vision_tensor:
+            # skip vision tensors
+            return []
+
+        name = name.replace("language_model.", "")
+
         # conv op requires 2d tensor
         if 'conv.conv' in name:
             data_torch = data_torch.squeeze(1)
@@ -8294,6 +8300,41 @@ class LFM2Model(TextModel):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Lfm2VlForConditionalGeneration")
+class LFM2VLModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        # TODO(tarek): for dynamic resolution image_size is not specified, setting here for compatibility
+        self.hparams_vision["image_size"] = 256
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LFM2)
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.find_vparam(["layer_norm_eps"]))
+        self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("downsample_factor", 2))
+        self.gguf_writer.add_vision_use_gelu(True)
+        # python notation, e.g. for vision_feature_layer == -1, we pick last layer -> vision_feature_layers_to_drop = 0
+        vision_feature_layers_to_drop = -(self.global_config.get("vision_feature_layer", -1) + 1)
+        self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys) - vision_feature_layers_to_drop)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+
+        if is_vision_tensor:
+            # remove "model." prefix
+            name = name.replace("model.vision_tower.", "vision_tower.")
+            name = name.replace("model.multi_modal_projector.", "multi_modal_projector.")
+
+            if "patch_embedding.weight" in name:
+                data_torch = data_torch.view(data_torch.shape[0], 16, 16, 3).permute(0, 3, 1, 2)
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return [] # skip other tensors
+
+
 @ModelBase.register("SmallThinkerForCausalLM")
 class SmallThinkerModel(TextModel):
     model_arch = gguf.MODEL_ARCH.SMALLTHINKER
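
The only vision tensor that needs reshaping during conversion is the patch embedding: the SigLIP2-NaFlex checkpoint stores it as a flat linear weight, and the reshape above puts it into a conv2d-style layout for clip.cpp. A minimal torch sketch of that transformation (the hidden size of 768 is a placeholder, not taken from the model):

import torch

# Stand-in for the HF "vision_tower...patch_embedding.weight" tensor:
# a flat linear weight over 16x16 patches with 3 channels.
hidden_size, patch, channels = 768, 16, 3  # hidden_size is a placeholder value
w_linear = torch.randn(hidden_size, patch * patch * channels)

# Same reshape as in LFM2VLModel.modify_tensors above:
# (hidden, 16*16*3) -> (hidden, 16, 16, 3) -> (hidden, 3, 16, 16)
w_conv = w_linear.view(hidden_size, patch, patch, channels).permute(0, 3, 1, 2)

print(w_conv.shape)  # torch.Size([768, 3, 16, 16])
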
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 911eea504a19e0735c2b6c5114ad69a0544f1518..41804f3a2bb1a3fc9b86c27d18dc6de69b0b39d0 100644 (file)
@@ -2832,6 +2832,7 @@ class VisionProjectorType:
     QWEN2A = "qwen2a" # audio
     QWEN25O = "qwen2.5o" # omni
     VOXTRAL = "voxtral"
+    LFM2 = "lfm2"
 
 
 # Items here are (block size, type size)
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index c5c27980905de9ed1303290ff979513323a52184..87edaa3232ccc3339158cd433b9f89647ee1510b 100644 (file)
@@ -1272,6 +1272,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MM_INP_NORM: (
             "multi_modal_projector.norm",
+            "multi_modal_projector.layer_norm",
             "pre_mm_projector_norm",
         ),
 
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index f1eb6333693760e5f9981cda48b8a447d4a8af1e..706ed2e3b5e21b72625ae7c7e27ba9b289bfdaa0 100644 (file)
@@ -82,6 +82,7 @@
 #define TN_MVLM_PROJ_PEG   "mm.model.peg.%d.%s"
 #define TN_IMAGE_NEWLINE   "model.image_newline"
 #define TN_MM_INP_NORM     "mm.input_norm.weight"
+#define TN_MM_INP_NORM_B   "mm.input_norm.bias"
 #define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"    // gemma3
 #define TN_MM_PROJECTOR    "mm.model.fc.weight"         // idefics3
@@ -133,6 +134,7 @@ enum projector_type {
     PROJECTOR_TYPE_QWEN2A,
     PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_VOXTRAL,
+    PROJECTOR_TYPE_LFM2,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -153,6 +155,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_QWEN2A,    "qwen2a"},
     { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
     { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
+    { PROJECTOR_TYPE_LFM2,      "lfm2"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index fdaf9738e88cbbc105c08770154cf45d89d9b5c9..c27f8ebbd9912e21344f34e32cb06cbd06dd8916 100644 (file)
@@ -265,6 +265,7 @@ struct clip_model {
 
     // LLaVA projection
     ggml_tensor * mm_input_norm_w = nullptr;
+    ggml_tensor * mm_input_norm_b = nullptr;
     ggml_tensor * mm_0_w = nullptr;
     ggml_tensor * mm_0_b = nullptr;
     ggml_tensor * mm_2_w = nullptr;
@@ -488,11 +489,17 @@ struct clip_graph {
 
     ggml_cgraph * build_siglip() {
         ggml_tensor * inp = build_inp();
+
+        ggml_tensor * learned_pos_embd = model.position_embeddings;
+        if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+            learned_pos_embd = resize_position_embeddings();
+        }
+
         ggml_tensor * cur = build_vit(
                                 inp, n_patches,
                                 NORM_TYPE_NORMAL,
                                 hparams.ffn_op,
-                                model.position_embeddings,
+                                learned_pos_embd,
                                 nullptr);
 
         if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) {
@@ -542,6 +549,45 @@ struct clip_graph {
                 bsz);
 
             cur = ggml_mul_mat(ctx0, model.projection, cur);
+        } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+            // pixel unshuffle block
+            const int scale_factor = model.hparams.proj_scale_factor;
+            GGML_ASSERT(scale_factor > 1);
+
+            const int n_embd = cur->ne[0];
+            int width  = img.nx / patch_size;
+            int height = img.ny / patch_size;
+
+            // pad width and height to factor
+            const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
+            const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
+            cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
+            if (pad_width || pad_height) {
+                cur     = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
+                width  += pad_width;
+                height += pad_height;
+            }
+
+            // unshuffle h
+            cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
+            cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
+
+            // unshuffle w
+            cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
+            cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
+
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+
+            // projection
+            cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
+            cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
+            cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
+
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_2_b);
         } else {
             GGML_ABORT("SigLIP: Unsupported projector type");
         }
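
The pixel-unshuffle block above is what proj_scale_factor controls: each scale_factor x scale_factor neighborhood of patch embeddings is folded into the channel dimension, cutting the token count by scale_factor^2 before the two-layer GELU MLP projects into the LLM embedding space. A rough torch sketch of the same shuffle on a hypothetical patch grid (it illustrates the idea, not the exact ggml permute order):

import torch
import torch.nn.functional as F

def pixel_unshuffle(x: torch.Tensor, scale: int) -> torch.Tensor:
    """x: (height, width, n_embd) grid of patch embeddings.
    Returns (height/scale * width/scale, n_embd * scale * scale) tokens."""
    h, w, c = x.shape
    # pad height/width up to a multiple of `scale` (the graph uses ggml_pad)
    ph, pw = (-h) % scale, (-w) % scale
    x = F.pad(x, (0, 0, 0, pw, 0, ph))
    h, w = h + ph, w + pw
    # fold each scale x scale neighborhood into the channel dimension
    x = x.reshape(h // scale, scale, w // scale, scale, c)
    x = x.permute(0, 2, 1, 3, 4)
    return x.reshape((h // scale) * (w // scale), scale * scale * c)

tokens = pixel_unshuffle(torch.randn(24, 32, 768), scale=2)
print(tokens.shape)  # torch.Size([192, 3072]) -- 4x fewer tokens, 4x wider
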
@@ -1560,6 +1606,27 @@ private:
         }
     }
 
+    // siglip2 naflex
+    ggml_tensor * resize_position_embeddings() {
+        ggml_tensor * pos_embd = model.position_embeddings;
+        const int height       = img.ny / patch_size;
+        const int width        = img.nx / patch_size;
+
+        if (!pos_embd || height * width == pos_embd->ne[1]) {
+            return pos_embd;
+        }
+
+        const int n_pos_embd = std::sqrt(pos_embd->ne[1]);
+        pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_pos_embd, n_pos_embd);  // -> (n_embd, n_pos_embd, n_pos_embd)
+        pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3);                         // -> (n_pos_embd, n_pos_embd, n_embd)
+        pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, 1);    // -> (width, height, n_embd)
+        pos_embd = ggml_reshape_2d(ctx0, pos_embd, height * width, n_embd);          // -> (height * width, n_embd)
+        pos_embd = ggml_transpose(ctx0, pos_embd);                                   // -> (n_embd, height * width)
+        pos_embd = ggml_cont(ctx0, pos_embd);
+
+        return pos_embd;
+    }
+
     // build vision transformer (ViT) cgraph
     // this function should cover most of the models
     // if your model has specific features, you should probably duplicate this function
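
resize_position_embeddings() is the SigLIP2-NaFlex part of the change: the learned position embeddings cover a fixed square grid, and with dynamic resolution they are interpolated to the actual patch grid of each resized image. A hedged torch equivalent, assuming bilinear interpolation:

import torch
import torch.nn.functional as F

def resize_pos_embd(pos_embd: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """pos_embd: (n_pos, n_embd) embeddings learned for a square sqrt(n_pos) grid.
    Returns (height * width, n_embd) embeddings for the target patch grid."""
    n_pos, n_embd = pos_embd.shape
    if height * width == n_pos:
        return pos_embd
    side = int(n_pos ** 0.5)
    # (n_pos, n_embd) -> (1, n_embd, side, side) so interpolate works on the grid
    grid = pos_embd.reshape(side, side, n_embd).permute(2, 0, 1).unsqueeze(0)
    grid = F.interpolate(grid, size=(height, width), mode="bilinear", align_corners=False)
    return grid.squeeze(0).permute(1, 2, 0).reshape(height * width, n_embd)

# e.g. a 16x16 learned grid stretched to a 12x20 patch grid
pe = resize_pos_embd(torch.randn(256, 768), height=12, width=20)
print(pe.shape)  # torch.Size([240, 768])
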
@@ -1966,6 +2033,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     switch (ctx->proj_type()) {
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_LFM2:
             {
                 res = graph.build_siglip();
             } break;
@@ -2230,6 +2298,7 @@ struct clip_model_loader {
                         }
                     } break;
                 case PROJECTOR_TYPE_IDEFICS3:
+                case PROJECTOR_TYPE_LFM2:
                 case PROJECTOR_TYPE_INTERNVL:
                     {
                         get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
@@ -2533,6 +2602,15 @@ struct clip_model_loader {
                 {
                     model.projection = get_tensor(TN_MM_PROJECTOR);
                 } break;
+            case PROJECTOR_TYPE_LFM2:
+                {
+                    model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
+                    model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                } break;
             case PROJECTOR_TYPE_PIXTRAL:
                 {
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
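
For reference, combining TN_MM_INP_NORM / TN_MM_INP_NORM_B from clip-impl.h with TN_LLAVA_PROJ (assumed here to be the "mm.%d.%s" pattern defined elsewhere in clip-impl.h), the loader case above expects these GGUF tensor names for the LFM2 projector:

TN_LLAVA_PROJ = "mm.%d.%s"  # assumed value, defined elsewhere in clip-impl.h

lfm2_projector_tensors = [
    "mm.input_norm.weight",         # TN_MM_INP_NORM
    "mm.input_norm.bias",           # TN_MM_INP_NORM_B (added in this commit)
    TN_LLAVA_PROJ % (1, "weight"),  # mm.1.weight
    TN_LLAVA_PROJ % (1, "bias"),    # mm.1.bias
    TN_LLAVA_PROJ % (2, "weight"),  # mm.2.weight
    TN_LLAVA_PROJ % (2, "bias"),    # mm.2.bias
]
print("\n".join(lfm2_projector_tensors))
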
@@ -3428,6 +3506,43 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->grid_y = inst.grid_size.height;
         return true;
 
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+        GGML_ASSERT(params.proj_scale_factor);
+
+        // smart resize
+        const int width = img->nx;
+        const int height = img->ny;
+        const int total_factor = params.patch_size * params.proj_scale_factor;
+        constexpr int min_image_tokens = 64;
+        constexpr int max_image_tokens = 256;
+        const float min_pixels = min_image_tokens * total_factor * total_factor;
+        const float max_pixels = max_image_tokens * total_factor * total_factor;
+
+        auto round_by_factor = [f = total_factor](float x) { return static_cast<int>(std::nearbyintf(x / static_cast<float>(f))) * f; };
+        auto ceil_by_factor  = [f = total_factor](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
+        auto floor_by_factor = [f = total_factor](float x) { return static_cast<int>(std::floor(x / static_cast<float>(f))) * f; };
+
+        int h_bar = std::max(total_factor, round_by_factor(height));
+        int w_bar = std::max(total_factor, round_by_factor(width));
+
+        if (h_bar * w_bar > max_pixels) {
+            const auto beta = std::sqrt((height * width) / max_pixels);
+            h_bar = std::max(total_factor, floor_by_factor(height / beta));
+            w_bar = std::max(total_factor, floor_by_factor(width / beta));
+        } else if (h_bar * w_bar < min_pixels) {
+            const auto beta = std::sqrt(min_pixels / (height * width));
+            h_bar = ceil_by_factor(height * beta);
+            w_bar = ceil_by_factor(width * beta);
+        }
+
+        const std::array<uint8_t, 3> pad_color = {122, 116, 104};
+
+        clip_image_u8 resized_img;
+        image_manipulation::resize_and_pad_image(*img, resized_img, clip_image_size{w_bar, h_bar}, pad_color);
+        clip_image_f32_ptr res(clip_image_f32_init());
+        normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
+        res_imgs->entries.push_back(std::move(res));
+        return true;
     }
 
     // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
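
The preprocessing branch above implements the dynamic-resolution "smart resize": dimensions are snapped to multiples of patch_size * proj_scale_factor (32 for patch size 16 and downsample factor 2) and scaled so the post-unshuffle token count stays within [64, 256], then the image is padded with rgb(122, 116, 104). A small Python sketch of the same size computation, under those default assumptions:

import math

def smart_resize(width: int, height: int, patch_size: int = 16, scale_factor: int = 2,
                 min_tokens: int = 64, max_tokens: int = 256) -> tuple[int, int]:
    """Snap (width, height) to multiples of patch_size*scale_factor while keeping
    the post-unshuffle token count within [min_tokens, max_tokens]."""
    f = patch_size * scale_factor
    round_f = lambda x: max(f, round(x / f) * f)
    floor_f = lambda x: max(f, math.floor(x / f) * f)
    ceil_f  = lambda x: math.ceil(x / f) * f

    min_px, max_px = min_tokens * f * f, max_tokens * f * f
    w_bar, h_bar = round_f(width), round_f(height)
    if w_bar * h_bar > max_px:    # too many tokens -> shrink
        beta = math.sqrt(width * height / max_px)
        w_bar, h_bar = floor_f(width / beta), floor_f(height / beta)
    elif w_bar * h_bar < min_px:  # too few tokens -> enlarge
        beta = math.sqrt(min_px / (width * height))
        w_bar, h_bar = ceil_f(width * beta), ceil_f(height * beta)
    return w_bar, h_bar

# A 1024x768 photo becomes 576x416, i.e. (576/32) * (416/32) = 234 image tokens,
# matching clip_n_output_tokens() for PROJECTOR_TYPE_LFM2.
print(smart_resize(1024, 768))  # (576, 416)
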
@@ -3630,6 +3745,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                     n_patches_sq /= 2;
                 }
             } break;
+        case PROJECTOR_TYPE_LFM2:
+            {
+                n_patches_sq = (img->nx / (params.patch_size * params.proj_scale_factor)) * (img->ny / (params.patch_size * params.proj_scale_factor));
+            } break;
         default:
             GGML_ABORT("unsupported projector type");
     }
@@ -4034,6 +4153,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_INTERNVL:
         case PROJECTOR_TYPE_QWEN2A:
         case PROJECTOR_TYPE_ULTRAVOX:
+        case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_VOXTRAL:
             {
                 // do nothing
@@ -4135,6 +4255,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_model_proj->ne[1];
         case PROJECTOR_TYPE_QWEN2A:
             return ctx->model.mm_fc_w->ne[1];
+        case PROJECTOR_TYPE_LFM2:
+            return ctx->model.mm_2_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }