git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model: Add PaddleOCR-VL model support (#18825)
author megemini <redacted>
Thu, 19 Feb 2026 16:05:25 +0000 (00:05 +0800)
committer GitHub <redacted>
Thu, 19 Feb 2026 16:05:25 +0000 (17:05 +0100)
* support PaddleOCR-VL

* clip: update PaddleOCR model loader parameters to prevent OOM during warmup

* [update] add paddleocr vl text model instead of ernie4.5

* [update] restore change of minicpmv

* [update] format

* [update] format

* [update] positions and patch merge permute

* [update] mtmd_decode_use_mrope for paddleocr

* [update] image min/max pixels

* [update] remove set_limit_image_tokens

* update: preprocess without padding

* clean up

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
---------

Co-authored-by: Xuan Son Nguyen <redacted>
Co-authored-by: Sigbjørn Skjæret <redacted>
16 files changed:
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/CMakeLists.txt
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/llama-vocab.cpp
src/models/models.h
src/models/paddleocr.cpp [new file with mode: 0644]
tools/mtmd/CMakeLists.txt
tools/mtmd/clip-impl.h
tools/mtmd/clip.cpp
tools/mtmd/models/models.h
tools/mtmd/models/paddleocr.cpp [new file with mode: 0644]
tools/mtmd/mtmd.cpp

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 7eeb3aa9035f9e1f89230253a58f9aa13ff1a148..31acd5bb48546b5aeb4a43d32ad97cd3eeda9f9a 100755 (executable)
@@ -3733,6 +3733,13 @@ class Ernie4_5Model(TextModel):
     def set_vocab(self):
         self._set_vocab_sentencepiece()
 
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+                tokenizer_config_json = json.load(f)
+                if "add_prefix_space" in tokenizer_config_json:
+                    self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
@@ -3742,6 +3749,10 @@ class Ernie4_5Model(TextModel):
         if (head_dim := self.hparams.get("head_dim")) is None:
             head_dim = self.hparams["hidden_size"] // num_heads
 
+        if "mlp_AR" in name or "vision_model" in name:
+            # skip vision model and projector tensors
+            return
+
         if "ernie." in name:
             name = name.replace("ernie.", "model.")
         # split the qkv weights
@@ -3851,6 +3862,48 @@ class Ernie4_5MoeModel(Ernie4_5Model):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("PaddleOCRVLForConditionalGeneration")
+class PaddleOCRModel(Ernie4_5Model):
+    model_arch = gguf.MODEL_ARCH.PADDLEOCR
+
+
+@ModelBase.register("PaddleOCRVisionModel")
+class PaddleOCRVisionModel(MmprojModel):
+    # PaddleOCR-VL uses a modified version of Siglip
+    min_pixels: int = 0
+    max_pixels: int = 0
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        self.min_pixels = self.preprocessor_config["min_pixels"]
+        self.max_pixels = self.preprocessor_config["max_pixels"]
+        self.hparams_vision["image_size"] = int(math.sqrt(self.max_pixels))
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        assert self.hparams_vision is not None
+        hparams = self.hparams_vision
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PADDLEOCR)
+        self.gguf_writer.add_vision_max_pixels(self.max_pixels)
+        self.gguf_writer.add_vision_min_pixels(self.min_pixels)
+        self.gguf_writer.add_vision_use_gelu(True)
+        self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-6))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        name = name.replace("visual.", "model.")
+
+        if "vision_model" in name or "mlp_AR" in name:
+            if "packing_position_embedding" in name:
+                return # unused
+            elif "vision_model.head" in name:
+                # we don't yet support image embeddings for this model
+                return
+            else:
+                yield from super().modify_tensors(data_torch, name, bid)
+        return # skip other tensors
+
+
 @ModelBase.register(
     "Qwen2VLModel",
     "Qwen2VLForConditionalGeneration",
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index e90826dd1be74ba76d70d5e175acae1ed5627469..689acdc65de30a3baf15166e67de11ddf04d5b8c 100644 (file)
@@ -473,6 +473,7 @@ class MODEL_ARCH(IntEnum):
     RND1             = auto()
     PANGU_EMBED      = auto()
     MISTRAL3         = auto()
+    PADDLEOCR        = auto()
     MIMO2            = auto()
     STEP35           = auto()
     LLAMA_EMBED      = auto()
@@ -914,6 +915,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.RND1:             "rnd1",
     MODEL_ARCH.PANGU_EMBED:      "pangu-embedded",
     MODEL_ARCH.MISTRAL3:         "mistral3",
+    MODEL_ARCH.PADDLEOCR:        "paddleocr",
     MODEL_ARCH.MIMO2:            "mimo2",
     MODEL_ARCH.STEP35:           "step35",
     MODEL_ARCH.LLAMA_EMBED:      "llama-embed",
@@ -3186,6 +3188,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.PADDLEOCR: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.FALCON_H1: [
         # Token embedding
         MODEL_TENSOR.TOKEN_EMBD,
@@ -3847,6 +3863,7 @@ class VisionProjectorType:
     VOXTRAL = "voxtral"
     LFM2 = "lfm2"
     KIMIVL = "kimivl"
+    PADDLEOCR = "paddleocr"
     KIMIK25 = "kimik25"
     LIGHTONOCR = "lightonocr"
     COGVLM = "cogvlm"
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 5fc75c52eb8d37763f07ab92ba426fdfb6503210..fc468d07745a6edd98898f01716f432c094e0e70 100644 (file)
@@ -1325,6 +1325,7 @@ class TensorNameMap:
             "multi_modal_projector.linear_{bid}",
             "mm_projector.proj.linear_{bid}", # Kimi-K2.5
             "visual.merger.mlp.{bid}", # qwen2vl
+            "mlp_AR.linear_{bid}", # PaddleOCR-VL
             "merger.mlp.{bid}",
         ),
 
@@ -1574,6 +1575,7 @@ class TensorNameMap:
             "mm_projector.pre_norm", # Kimi-K2.5
             "pre_mm_projector_norm",
             "model.vision.linear_proj.norm1", # cogvlm
+            "mlp_AR.pre_norm", # PaddleOCR-VL
             "merger.ln_q",
         ),
 
@@ -1599,6 +1601,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
             "resampler.attn.out_proj",
+            "model.vision_model.head.attention.out_proj",
         ),
 
         MODEL_TENSOR.V_RESMPL_KV: (
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c10d5c70fbf24a725dd97dce4fb4b8e5d5f2b7dd..2a661a1fe884179d9dc161eb62c8fe3acbec028b 100644 (file)
@@ -110,6 +110,7 @@ add_library(llama
             models/openai-moe-iswa.cpp
             models/openelm.cpp
             models/orion.cpp
+            models/paddleocr.cpp
             models/pangu-embedded.cpp
             models/phi2.cpp
             models/phi3.cpp
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 3cb45b6922df3d8e4f79e4f22f887065d4850ce7..39ebb9db027064bb99e851d28d58ed02a16e378b 100644 (file)
@@ -121,6 +121,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_RND1,             "rnd1"             },
     { LLM_ARCH_PANGU_EMBED,      "pangu-embedded"   },
     { LLM_ARCH_MISTRAL3,         "mistral3"         },
+    { LLM_ARCH_PADDLEOCR,        "paddleocr"        },
     { LLM_ARCH_MIMO2,            "mimo2"            },
     { LLM_ARCH_STEP35,           "step35"           },
     { LLM_ARCH_LLAMA_EMBED,      "llama-embed"      },
@@ -739,6 +740,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
         case LLM_ARCH_INTERNLM2:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_ERNIE4_5:
+        case LLM_ARCH_PADDLEOCR:
         case LLM_ARCH_SMOLLM3:
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 43ca9a6a486771a0ad2ff7209440232e24954862..11daa141334cd6431bcc0440d31cb778cbb5dd26 100644 (file)
@@ -125,6 +125,7 @@ enum llm_arch {
     LLM_ARCH_RND1,
     LLM_ARCH_PANGU_EMBED,
     LLM_ARCH_MISTRAL3,
+    LLM_ARCH_PADDLEOCR,
     LLM_ARCH_MIMO2,
     LLM_ARCH_STEP35,
     LLM_ARCH_LLAMA_EMBED,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 2ff80d6735ce76c5eac495c67d02f1ce99c42f9f..764839b9bc86445f3816fdb54af039342e3adac3 100644 (file)
@@ -2244,7 +2244,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_ERNIE4_5:
         case LLM_ARCH_ERNIE4_5_MOE:
+        case LLM_ARCH_PADDLEOCR:
             {
+                // paddleocr needs mrope_section
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 if (arch == LLM_ARCH_ERNIE4_5_MOE) {
                     ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
@@ -6631,6 +6635,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_ERNIE4_5:
             case LLM_ARCH_ERNIE4_5_MOE:
+            case LLM_ARCH_PADDLEOCR:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
@@ -8709,6 +8714,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_ernie4_5_moe>(*this, params);
             } break;
+        case LLM_ARCH_PADDLEOCR:
+            {
+                llm = std::make_unique<llm_build_paddleocr>(*this, params);
+            } break;
         case LLM_ARCH_HUNYUAN_MOE:
             {
                 llm = std::make_unique<llm_build_hunyuan_moe>(*this, params);
@@ -9045,6 +9054,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
+        case LLM_ARCH_PADDLEOCR:
             return LLAMA_ROPE_TYPE_MROPE;
         case LLM_ARCH_QWEN3VL:
         case LLM_ARCH_QWEN3VLMOE:
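
Because LLM_ARCH_PADDLEOCR now maps to LLAMA_ROPE_TYPE_MROPE, the decoder consumes four position planes per token instead of one, following the qwen2vl convention. A hedged sketch of that convention (the actual filling is done by mtmd; in the qwen2vl scheme, plain text tokens repeat the same index in every plane, while image tokens spread their temporal/y/x coordinates):

def text_positions(n_past: int, n_tokens: int) -> list[int]:
    # four identical planes for a text-only batch; image tokens would instead
    # carry (temporal, y, x, unused) coordinates across the four planes
    base = list(range(n_past, n_past + n_tokens))
    return base * 4
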
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 657df711efd90da25669fc37c2f631d99b50da67..69b25a1bf9f59f4db9b65e67d6577f4b7a9347a7 100644 (file)
@@ -2470,6 +2470,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|calls|>"  // solar-open
                     || t.first == "<end_of_turn>"
                     || t.first == "<|endoftext|>"
+                    || t.first == "</s>"      // paddleocr
                     || t.first == "<|eom_id|>"
                     || t.first == "<EOT>"
                     || t.first == "_<EOT>"
diff --git a/src/models/models.h b/src/models/models.h
index f8ef68cffd773db80966e81dd735c6a169124b13..10f8b58921ecbf28cea2d72c9b697f26c97bc52e 100644 (file)
@@ -190,6 +190,10 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
     llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_paddleocr : public llm_graph_context {
+    llm_build_paddleocr(const llama_model & model, const llm_graph_params & params);
+};
+
 template <bool iswa>
 struct llm_build_exaone4 : public llm_graph_context {
     llm_build_exaone4(const llama_model & model, const llm_graph_params & params);
diff --git a/src/models/paddleocr.cpp b/src/models/paddleocr.cpp
new file mode 100644 (file)
index 0000000..39a368d
--- /dev/null
@@ -0,0 +1,122 @@
+#include "models.h"
+
+llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params) {
+
+    // NOTE: same as qwen2vl.cpp, but bias tensors are optional
+
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // norm
+        {
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+        }
+        // self-attention
+        {
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+            if (model.layers[il].bq) {
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+            }
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+            if (model.layers[il].bk) {
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+            }
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+            if (model.layers[il].bv) {
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+            }
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = ggml_rope_multi(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_multi(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+        if (il == n_layer - 1) {
+            // skip computing output for unused tokens
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // feed-forward network
+        {
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up, NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        }
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
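
The graph above rotates Q and K with ggml_rope_multi using the four rope_sections loaded in llama-model.cpp. Roughly, the sections partition the rotary pairs into contiguous blocks, and each block takes its rotation angle from one of the four position components. A rough Python model of the idea, for intuition only (the real kernel lives in ggml and also handles wrap-around when the sections don't cover all pairs):

def mrope_angles(pos4: list[int], sections: list[int], freq_base: float = 10000.0) -> list[float]:
    # pos4: one token's 4 position components; sections: rotary pairs per component
    n_pairs = sum(sections)
    comp_of_pair: list[int] = []
    for comp, s in enumerate(sections):
        comp_of_pair += [comp] * s
    # standard RoPE frequencies, but the driving position varies per section
    return [pos4[comp_of_pair[i]] * freq_base ** (-i / n_pairs) for i in range(n_pairs)]
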
diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt
index 755a3d4b0062484e1111c16f473ad97404cec9e4..3be3c27e87bb721da26414bb3e98755133bcf748 100644 (file)
@@ -24,6 +24,7 @@ add_library(mtmd
             models/llama4.cpp
             models/llava.cpp
             models/minicpmv.cpp
+            models/paddleocr.cpp
             models/pixtral.cpp
             models/qwen2vl.cpp
             models/qwen3vl.cpp
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 03bedf9d3fd6985230ce8e1d16ddb98b10749b3d..a30c32ed42bf7962c5bdee1dd37b3e5cf954ade5 100644 (file)
@@ -229,6 +229,7 @@ enum projector_type {
     PROJECTOR_TYPE_MUSIC_FLAMINGO,
     PROJECTOR_TYPE_LFM2,
     PROJECTOR_TYPE_KIMIVL,
+    PROJECTOR_TYPE_PADDLEOCR,
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_COGVLM,
     PROJECTOR_TYPE_JANUS_PRO,
@@ -264,6 +265,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_MUSIC_FLAMINGO, "musicflamingo"},
     { PROJECTOR_TYPE_LFM2,      "lfm2"},
     { PROJECTOR_TYPE_KIMIVL,    "kimivl"},
+    { PROJECTOR_TYPE_PADDLEOCR, "paddleocr"},
     { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
     { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
     { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 57f6dd00a38fda3f10ce1420e9b0d055f91e91c2..607d4b837318675243869b7c6ab0d87e152b1c17 100644 (file)
@@ -841,6 +841,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique<clip_graph_kimivl>(ctx, img);
             } break;
+        case PROJECTOR_TYPE_PADDLEOCR:
+            {
+                builder = std::make_unique<clip_graph_paddleocr>(ctx, img);
+            } break;
         case PROJECTOR_TYPE_KIMIK25:
             {
                 builder = std::make_unique<clip_graph_kimik25>(ctx, img);
@@ -1256,6 +1260,14 @@ struct clip_model_loader {
                         hparams.audio_window_len   = 400;
                         hparams.audio_hop_len      = 160;
                     } break;
+                case PROJECTOR_TYPE_PADDLEOCR:
+                    {
+                        hparams.n_merge = 2;
+                        get_u32(KEY_IMAGE_MIN_PIXELS, hparams.image_min_pixels);
+                        get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels);
+
+                        hparams.set_warmup_n_tokens(28*28); // avoid OOM on warmup
+                    } break;
                 case PROJECTOR_TYPE_LFM2A:
                     {
                         // audio preprocessing params
@@ -1704,6 +1716,7 @@ struct clip_model_loader {
                     model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
                 } break;
             case PROJECTOR_TYPE_KIMIVL:
+            case PROJECTOR_TYPE_PADDLEOCR:
             case PROJECTOR_TYPE_KIMIK25:
                 {
                     model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
@@ -2990,6 +3003,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
         case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_PADDLEOCR:
             {
                 GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
                 clip_image_u8 resized;
@@ -3330,6 +3344,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
         case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_PADDLEOCR:
         case PROJECTOR_TYPE_YOUTUVL:
             return (img->nx / params.patch_size) / 2;
         default:
@@ -3346,6 +3361,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
         case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_PADDLEOCR:
         case PROJECTOR_TYPE_YOUTUVL:
             return (img->ny / params.patch_size) / 2;
         default:
@@ -3443,6 +3459,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                 int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
                 n_patches = x_patch * y_patch;
             } break;
+        case PROJECTOR_TYPE_PADDLEOCR:
+            {
+                // dynamic size
+                int n_merge = ctx->model.hparams.n_merge;
+                int stride = n_merge * n_merge;
+                n_patches = CLIP_ALIGN(n_patches, stride) / stride;
+            } break;
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_LIGHTONOCR:
             {
@@ -3690,6 +3713,30 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                     }
                 }
 
+                set_input_i32("positions", positions);
+            } break;
+        case PROJECTOR_TYPE_PADDLEOCR:
+            {
+                const int merge_ratio = hparams.n_merge;
+                const int pw = image_size_width  / patch_size;
+                const int ph = image_size_height / patch_size;
+                std::vector<int> positions(n_pos * 4);
+                int ptr = 0;
+                // NOTE: same as Qwen-VL, but x and y are swapped
+                for (int y = 0; y < ph; y += merge_ratio) {
+                    for (int dy = 0; dy < 2; dy++) {
+                        for (int x = 0; x < pw; x += merge_ratio) {
+                            for (int dx = 0; dx < 2; dx++) {
+                                positions[                  ptr] = y + dy;
+                                positions[    num_patches + ptr] = x + dx;
+                                positions[2 * num_patches + ptr] = y + dy;
+                                positions[3 * num_patches + ptr] = x + dx;
+                                ptr++;
+                            }
+                        }
+                    }
+                }
+
                 set_input_i32("positions", positions);
             } break;
         case PROJECTOR_TYPE_QWEN25VL:
@@ -4003,6 +4050,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
+        case PROJECTOR_TYPE_PADDLEOCR:
         case PROJECTOR_TYPE_KIMIK25:
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_COGVLM:
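
For reference, the position layout built in the PROJECTOR_TYPE_PADDLEOCR branch of clip_image_batch_encode above, transcribed to plain Python: four planes of length n_pos (pw * ph patches), walked in 2x2 merge blocks, with the first two planes carrying (y, x) where Qwen-VL uses (x, y):

def paddleocr_positions(pw: int, ph: int, merge_ratio: int = 2) -> list[int]:
    n_pos = pw * ph
    positions = [0] * (n_pos * 4)
    ptr = 0
    for y in range(0, ph, merge_ratio):
        for dy in range(2):
            for x in range(0, pw, merge_ratio):
                for dx in range(2):
                    positions[ptr] = y + dy
                    positions[n_pos + ptr] = x + dx
                    positions[2 * n_pos + ptr] = y + dy
                    positions[3 * n_pos + ptr] = x + dx
                    ptr += 1
    return positions

The token count shrinks to match: clip_n_output_tokens aligns n_patches up to stride = n_merge * n_merge and divides by it, so a 28x28 patch grid becomes 784 / 4 = 196 output tokens.
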
diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h
index 0beff16c5ef31b11517c49d9de717170e0ac7158..aff222c71d3afeba23cf0e68f43669e732faf1f4 100644 (file)
@@ -57,6 +57,11 @@ struct clip_graph_kimivl : clip_graph {
     ggml_cgraph * build() override;
 };
 
+struct clip_graph_paddleocr : clip_graph {
+    clip_graph_paddleocr(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+    ggml_cgraph * build() override;
+};
+
 struct clip_graph_cogvlm : clip_graph {
     clip_graph_cogvlm(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
     ggml_cgraph * build() override;
diff --git a/tools/mtmd/models/paddleocr.cpp b/tools/mtmd/models/paddleocr.cpp
new file mode 100644 (file)
index 0000000..5d3a13f
--- /dev/null
@@ -0,0 +1,52 @@
+#include "models.h"
+
+ggml_cgraph * clip_graph_paddleocr::build() {
+    const int n_pos            = n_patches;
+    const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
+
+    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+
+    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+        return ggml_rope_multi(
+                    ctx0, cur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION,
+                    32768, 10000, 1, 0, 1, 32, 1);
+    };
+
+    ggml_tensor * learned_pos_embd = resize_position_embeddings();
+    ggml_tensor * inp = build_inp();
+    ggml_tensor * cur = build_vit(
+                            inp, n_patches,
+                            NORM_TYPE_NORMAL,
+                            hparams.ffn_op,
+                            learned_pos_embd,
+                            add_pos);
+
+    cb(cur, "vit_out", -1);
+
+    {
+        // mlp_AR paddleocr projector
+        float proj_norm_eps = 1e-5;
+        cur = build_norm(cur,
+                    model.mm_input_norm_w, model.mm_input_norm_b,
+                    NORM_TYPE_NORMAL, proj_norm_eps, -1);
+
+        const int scale_factor = model.hparams.n_merge;
+        cur = build_patch_merge_permute(cur, scale_factor);
+        cur = build_ffn(cur,
+                    model.mm_1_w, model.mm_1_b,
+                    nullptr, nullptr,
+                    model.mm_2_w, model.mm_2_b,
+                    hparams.ffn_op, -1);
+        cb(cur, "mlp_out", -1);
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}
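
Shape-wise, the mlp_AR projector above norms the ViT output, folds each n_merge x n_merge block of neighboring patch embeddings into a single token via build_patch_merge_permute, then maps it to the LLM embedding width with the two-layer mm_1/mm_2 MLP (GELU in between, per add_vision_use_gelu). A numpy sketch of the merge step, assuming a row-major [ph, pw, C] grid; the 1152-wide SigLIP features in the example are a plausible assumption, not read from this patch:

import numpy as np

def patch_merge(x: np.ndarray, n_merge: int = 2) -> np.ndarray:
    ph, pw, c = x.shape
    x = x.reshape(ph // n_merge, n_merge, pw // n_merge, n_merge, c)
    x = x.transpose(0, 2, 1, 3, 4)               # gather each 2x2 block of patches
    return x.reshape(-1, n_merge * n_merge * c)  # one token per block, 4*C channels

feats = np.random.rand(28, 28, 1152).astype(np.float32)
print(patch_merge(feats).shape)  # (196, 4608), then mm_1 -> GELU -> mm_2
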
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index af733d97d563c0e82dfef8e21de4fec5ff02f3e1..8ca979c86cf792794c11b8eba3c9fbfcc71426c8 100644 (file)
@@ -325,6 +325,10 @@ struct mtmd_context {
             img_beg = "<|begin_of_image|>";
             img_end = "<|end_of_image|>";
 
+        } else if (proj == PROJECTOR_TYPE_PADDLEOCR) {
+            // <|IMAGE_START|> ... (image embeddings) ... <|IMAGE_END|>
+            img_beg = "<|IMAGE_START|>";
+            img_end = "<|IMAGE_END|>";
         }
     }
 
@@ -890,6 +894,7 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) {
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
         case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_PADDLEOCR:
             return true;
         default:
             return false;