git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : add LightOnOCR-1B model (#16764)
author    Xuan-Son Nguyen <redacted>
Mon, 27 Oct 2025 15:02:58 +0000 (16:02 +0100)
committer GitHub <redacted>
Mon, 27 Oct 2025 15:02:58 +0000 (16:02 +0100)
* model : add LightOnOCR-1B model

* add test

convert_hf_to_gguf.py
gguf-py/gguf/constants.py
tools/mtmd/clip-impl.h
tools/mtmd/clip.cpp
tools/mtmd/mtmd.cpp
tools/mtmd/tests.sh

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 093f2ab467f4daa2ccec10bc189fbbc74e02fc43..b75936668439643e008142e3b3044ea54b6bed74 100755 (executable)
@@ -2460,18 +2460,21 @@ class ArceeModel(LlamaModel):
 )
 class LlavaVisionModel(MmprojModel):
     img_break_tok_id = -1
+    use_break_tok = True
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams.get("model_type") == "pixtral":
             # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
-            self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
+            if self.use_break_tok:
+                self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
         elif self.is_mistral_format:
             # hparams is already vision config here so norm_eps is only defined in global_config.
             self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
             assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
-            self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
+            if self.use_break_tok:
+                self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
         else:
             raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
         logger.info(f"Image break token id: {self.img_break_tok_id}")
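
The new use_break_tok flag lets Pixtral-style subclasses that ship no [IMG_BREAK] token reuse this constructor without attempting the token lookup (img_break_tok_id then simply stays at -1). A minimal sketch of such a subclass, with a hypothetical name; the concrete LightOnOCR class appears further down in this diff:

    # Sketch only: opting out of the [IMG_BREAK] lookup.
    class BreaklessVisionModel(LlavaVisionModel):  # hypothetical subclass name
        use_break_tok = False  # img_break_tok_id stays -1; no [IMG_BREAK] token is required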
@@ -3962,6 +3965,10 @@ class Qwen3Model(Qwen2Model):
         return torch.stack([true_row, false_row], dim=0)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "model.vision_" in name:
+            # skip multimodal tensors
+            return []
+
         if self.is_rerank:
             is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
             is_real_head = not self.is_tied_embeddings and "lm_head" in name
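
LightOnOCR's text backbone is converted through this Qwen3 path, and its checkpoint mixes language and vision weights, so the early return above keeps the vision tensors out of the text-model GGUF (they are picked up by the mmproj converter added below). A hedged illustration, with a hypothetical tensor name:

    # Tensors whose names contain "model.vision_" are dropped from the text-model conversion.
    name = "model.vision_encoder.patch_conv.weight"  # hypothetical example
    skip = "model.vision_" in name  # True -> modify_tensors() returns []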
@@ -9435,6 +9442,21 @@ class PixtralModel(LlavaVisionModel):
         return super().map_tensor_name(name, try_suffixes)
 
 
+@ModelBase.register("LightOnOCRForConditionalGeneration")
+class LightOnOCRVisionModel(LlavaVisionModel):
+    is_mistral_format = False
+    use_break_tok = False
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = name.replace("model.vision_encoder.", "vision_tower.")
+        name = name.replace("model.vision_projection.", "multi_modal_projector.")
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("KimiVLForConditionalGeneration")
 class KimiVLModel(MmprojModel):
     def __init__(self, *args, **kwargs):
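
The new converter class only renames the checkpoint prefixes before delegating to the shared Pixtral/LLaVA tensor mapping. A rough illustration of the rewrite; the suffix in the example name is hypothetical:

    # Sketch of the prefix rewrite in LightOnOCRVisionModel.modify_tensors().
    src = "model.vision_encoder.transformer.layers.0.attention.q_proj.weight"  # hypothetical suffix
    dst = (src.replace("model.vision_encoder.", "vision_tower.")
              .replace("model.vision_projection.", "multi_modal_projector."))
    # -> "vision_tower.transformer.layers.0.attention.q_proj.weight", a name the
    #    existing LlavaVisionModel/Pixtral mapping already knows how to handle.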
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 1b71fb3749aaa6483f9080159fd6077ecbbb1bfb..94fcfaf69cf099a7dbce247c7465b28718be288c 100644 (file)
@@ -3062,6 +3062,7 @@ class VisionProjectorType:
     VOXTRAL = "voxtral"
     LFM2 = "lfm2"
     KIMIVL = "kimivl"
+    LIGHTONOCR = "lightonocr"
 
 
 # Items here are (block size, type size)
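
The new constant is just the string that set_gguf_parameters() above writes into the mmproj GGUF via add_clip_projector_type(), and that clip.cpp resolves back to its projector enum with clip_projector_type_from_string(). For example:

    import gguf
    # The projector id stored in the mmproj file is this plain string.
    assert gguf.VisionProjectorType.LIGHTONOCR == "lightonocr"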
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 1669fad99b36b35300b4a284d7fe6bee2df22e31..ad2108d1798ae2e31c8efad146c905b512b6ed50 100644 (file)
@@ -139,6 +139,7 @@ enum projector_type {
     PROJECTOR_TYPE_VOXTRAL,
     PROJECTOR_TYPE_LFM2,
     PROJECTOR_TYPE_KIMIVL,
+    PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -161,6 +162,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
     { PROJECTOR_TYPE_LFM2,      "lfm2"},
     { PROJECTOR_TYPE_KIMIVL,    "kimivl"},
+    { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index f2abf88523843243d362c7f1e5859539d1767b43..08167625302c63ccc339e67eee9bb82937d06f93 100644 (file)
@@ -621,7 +621,7 @@ struct clip_graph {
         }
 
         // arrangement of the [IMG_BREAK] token
-        {
+        if (model.token_embd_img_break) {
             // not efficient, but works
             // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
             // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
@@ -2095,6 +2095,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                 res = graph.build_siglip();
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 res = graph.build_pixtral();
             } break;
@@ -2380,6 +2381,7 @@ struct clip_model_loader {
                         get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
                     } break;
                 case PROJECTOR_TYPE_PIXTRAL:
+                case PROJECTOR_TYPE_LIGHTONOCR:
                     {
                         hparams.rope_theta = 10000.0f;
                         hparams.warmup_image_size = hparams.patch_size * 8;
@@ -2722,6 +2724,15 @@ struct clip_model_loader {
                     model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM,     false);
                     model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
+            case PROJECTOR_TYPE_LIGHTONOCR:
+                {
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
+                    model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM,     false);
+                    model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
+                } break;
             case PROJECTOR_TYPE_ULTRAVOX:
                 {
                     model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
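
The LIGHTONOCR branch loads the same two-layer MLP projector as Pixtral, with the biases, input norm, and patch merger marked optional (the trailing false). Assuming TN_LLAVA_PROJ keeps its existing "mm.%d.%s" layout, the expected GGUF tensor names would be roughly:

    # Hedged sketch of the projector tensor names this branch looks for,
    # assuming the "mm.%d.%s" format string from clip-impl.h.
    required = ["mm.1.weight", "mm.2.weight"]
    optional = ["mm.1.bias", "mm.2.bias"]  # plus input-norm / patch-merger weights when present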
@@ -3622,7 +3633,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
 
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL
+            || ctx->proj_type() == PROJECTOR_TYPE_LIGHTONOCR
+    ) {
         clip_image_u8 resized_image;
         auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
         image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
@@ -3865,12 +3878,17 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                 n_patches = x_patch * y_patch;
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 // dynamic size
                 int n_merge = params.spatial_merge_size;
                 int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1);
                 int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1);
-                n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+                if (ctx->model.token_embd_img_break) {
+                    n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+                } else {
+                    n_patches = n_patches_y * n_patches_x;
+                }
             } break;
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_ULTRAVOX:
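
A worked example of the token count on this path, using hypothetical but typical values (patch_size 16, spatial_merge_size 2, a 1024x1024 preprocessed image): both projectors see a 32x32 merged grid, and only models that actually carry an [IMG_BREAK] embedding add the per-row break tokens:

    # Hedged example of clip_n_output_tokens() for the PIXTRAL/LIGHTONOCR case.
    patch_size, n_merge = 16, 2               # hypothetical hparams
    nx = ny = 1024                            # preprocessed image size
    n_patches_x = nx // patch_size // n_merge     # 32
    n_patches_y = ny // patch_size // n_merge     # 32
    with_break    = n_patches_x * n_patches_y + n_patches_y - 1  # 1055 (Pixtral: one [IMG_BREAK] per row except the last)
    without_break = n_patches_x * n_patches_y                    # 1024 (LightOnOCR: no break token)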
@@ -4247,6 +4265,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_KIMIVL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 // set the 2D positions
                 int n_patches_per_col = image_size_width / patch_size;
@@ -4377,6 +4396,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_model_peg_0_b->ne[0];
         case PROJECTOR_TYPE_MLP:
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_MLP_NORM:
             return ctx->model.mm_3_b->ne[0];
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 4d487581ae0a0202b976b47ea7cafd50e75fd4cc..3b901bfac82159e3b52787ae27f11507d1815440 100644 (file)
@@ -275,6 +275,11 @@ struct mtmd_context {
             img_beg = "<img>";
             img_end = "</img>";
 
+        } else if (proj == PROJECTOR_TYPE_LIGHTONOCR) {
+            // <|im_start|> ... (image embeddings) ... <|im_end|>
+            img_beg = "<|im_start|>";
+            img_end = "<|im_end|>";
+
         }
     }
 
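At tokenization time mtmd replaces each media marker in the prompt with img_beg, the image's embedding tokens, and img_end, so for LightOnOCR every image ends up wrapped in the Qwen-style pair shown above, schematically:

    ... user text ... <|im_start|>(N image embedding tokens)<|im_end|> ... user text ...
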
diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh
index dbdf7656a66d9537dda59ae7b0d1abb2d508b0b3..5e33d127649a0bbd1993bf919a9ed9924b7c74b4 100755 (executable)
@@ -70,6 +70,7 @@ add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0"
 add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
 add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0"
 add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0"
+add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
 
 add_test_audio  "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
 add_test_audio  "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
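
The new entry makes the mtmd vision test matrix download ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0 and exercise it like the other vision models. Outside the test script, the same quant can be tried directly with llama-mtmd-cli, e.g. llama-mtmd-cli -hf ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0 --image page.png -p "Transcribe the text in this image." (assuming the usual -hf/--image/-p flags rather than anything specific to this commit).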