]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model: Add support for CogVLM model (#15002)
authorTianyue-Zhao <redacted>
Thu, 30 Oct 2025 11:18:50 +0000 (07:18 -0400)
committerGitHub <redacted>
Thu, 30 Oct 2025 11:18:50 +0000 (12:18 +0100)
* Added GGUF mappings for CogVLM model

* Add tensor mapping for CogVLM visual encoder

* Add CogVLM to conversion script, no vision part yet

* Added CogVLM vision model to conversion script

* Add graph for CogVLM CLIP model

* Add graph for CogVLM

* Fixes for CogVLM. Now compiles.

* Model now runs

* Fixes for cogvlm graph

* Account for graph context change after rebase

* Changes for whitespace

* Changes in convert script according to comments

* Switch CogVLM LLM graph to merged QKV tensor

* Use rope_type variable instead of direct definition

* Change CogVLM CLIP encoder to use SWIGLU

* Switch CogVLM CLIP to use merged QKV

* Apply rebase edits and remove ggml_cont call that is now unnecessary

* clean up

---------

Co-authored-by: Xuan Son Nguyen <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/llama-model.h
tools/mtmd/clip-impl.h
tools/mtmd/clip.cpp

index b75936668439643e008142e3b3044ea54b6bed74..0fd8d5681d799e1f0d5d91f278d0ca79cfc2041b 100755 (executable)
@@ -1528,7 +1528,7 @@ class MmprojModel(ModelBase):
             self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
             self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
             self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
-            self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
+            self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
 
             # preprocessor config
             image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
@@ -9493,6 +9493,37 @@ class KimiVLModel(MmprojModel):
 
         return [] # skip other tensors
 
+
+@ModelBase.register("CogVLMForCausalLM")
+class CogVLMVisionModel(MmprojModel):
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.COGVLM)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if not name.startswith("model.vision."):
+            return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
+@ModelBase.register("CogVLMForCausalLM")
+class CogVLMModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.COGVLM
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # block vision tensors
+        if name.startswith("model.vision."):
+            return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
 ###### CONVERSION LOGIC ######
 
 
index 94fcfaf69cf099a7dbce247c7465b28718be288c..a0c08f69172e48fd4f379b2351dbd796380044b9 100644 (file)
@@ -420,6 +420,7 @@ class MODEL_ARCH(IntEnum):
     SEED_OSS         = auto()
     GROVEMOE         = auto()
     APERTUS          = auto()
+    COGVLM           = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -430,6 +431,7 @@ class VISION_PROJECTOR_TYPE(IntEnum):
     GLM_EDGE  = auto()
     MERGER    = auto()
     GEMMA3    = auto()
+    COGVLM    = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -600,6 +602,11 @@ class MODEL_TENSOR(IntEnum):
     SHORTCONV_CONV       = auto()
     SHORTCONV_INPROJ     = auto()
     SHORTCONV_OUTPROJ    = auto()
+    VISEXP_ATTN_QKV      = auto()
+    VISEXP_ATTN_OUT      = auto()
+    VISEXP_GATE          = auto()
+    VISEXP_DOWN          = auto()
+    VISEXP_UP            = auto()
     # vision
     V_MMPROJ             = auto()
     V_MMPROJ_FC          = auto()
@@ -609,6 +616,7 @@ class MODEL_TENSOR(IntEnum):
     V_ENC_EMBD_PATCH     = auto()
     V_ENC_EMBD_POS       = auto()
     V_ENC_INPUT_NORM     = auto()
+    V_ENC_ATTN_QKV       = auto()
     V_ENC_ATTN_Q         = auto()
     V_ENC_ATTN_Q_NORM    = auto()
     V_ENC_ATTN_K         = auto()
@@ -640,6 +648,12 @@ class MODEL_TENSOR(IntEnum):
     V_RESMPL_QUERY       = auto() # minicpmv
     V_TOK_EMBD_IMG_BREAK = auto() # pixtral
     V_MM_PATCH_MERGER    = auto() # mistral small 3.1
+    V_MM_POST_FC_NORM    = auto() # cogvlm
+    V_MM_UP              = auto() # cogvlm
+    V_MM_DOWN            = auto() # cogvlm
+    V_MM_GATE            = auto() # cogvlm
+    V_TOK_BOI            = auto() # cogvlm
+    V_TOK_EOI            = auto() # cogvlm
     # audio (mtmd)
     A_ENC_EMBD_POS       = auto()
     A_ENC_CONV1D         = auto()
@@ -766,6 +780,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.SEED_OSS:         "seed_oss",
     MODEL_ARCH.GROVEMOE:         "grovemoe",
     MODEL_ARCH.APERTUS:          "apertus",
+    MODEL_ARCH.COGVLM:           "cogvlm",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -946,6 +961,11 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.SHORTCONV_CONV:            "blk.{bid}.shortconv.conv",
     MODEL_TENSOR.SHORTCONV_INPROJ:          "blk.{bid}.shortconv.in_proj",
     MODEL_TENSOR.SHORTCONV_OUTPROJ:         "blk.{bid}.shortconv.out_proj",
+    MODEL_TENSOR.VISEXP_ATTN_QKV:           "blk.{bid}.vis_attn_qkv",
+    MODEL_TENSOR.VISEXP_ATTN_OUT:           "blk.{bid}.vis_attn_output",
+    MODEL_TENSOR.VISEXP_GATE:               "blk.{bid}.vis_gate",
+    MODEL_TENSOR.VISEXP_DOWN:               "blk.{bid}.vis_down",
+    MODEL_TENSOR.VISEXP_UP:                 "blk.{bid}.vis_up",
     # vision
     MODEL_TENSOR.V_MMPROJ:                  "mm.{bid}",
     MODEL_TENSOR.V_MMPROJ_FC:               "mm.model.fc",
@@ -954,6 +974,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_ENC_EMBD_CLS:            "v.class_embd",
     MODEL_TENSOR.V_ENC_EMBD_PATCH:          "v.patch_embd",
     MODEL_TENSOR.V_ENC_EMBD_POS:            "v.position_embd",
+    MODEL_TENSOR.V_ENC_ATTN_QKV:            "v.blk.{bid}.attn_qkv",
     MODEL_TENSOR.V_ENC_ATTN_Q:              "v.blk.{bid}.attn_q",
     MODEL_TENSOR.V_ENC_ATTN_Q_NORM:         "v.blk.{bid}.attn_q_norm",
     MODEL_TENSOR.V_ENC_ATTN_K:              "v.blk.{bid}.attn_k",
@@ -986,6 +1007,12 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_RESMPL_QUERY:            "resampler.query",
     MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK:      "v.token_embd.img_break", # pixtral
     MODEL_TENSOR.V_MM_PATCH_MERGER:         "mm.patch_merger", # mistral small 3.1
+    MODEL_TENSOR.V_MM_POST_FC_NORM:         "mm.post_fc_norm", # cogvlm
+    MODEL_TENSOR.V_MM_UP:                   "mm.up",
+    MODEL_TENSOR.V_MM_DOWN:                 "mm.down",
+    MODEL_TENSOR.V_MM_GATE:                 "mm.gate",
+    MODEL_TENSOR.V_TOK_BOI:                 "v.boi",
+    MODEL_TENSOR.V_TOK_EOI:                 "v.eoi",
     # audio (mtmd)
     MODEL_TENSOR.A_ENC_EMBD_POS:            "a.position_embd",
     MODEL_TENSOR.A_ENC_CONV1D:              "a.conv1d.{bid}",
@@ -1023,6 +1050,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_ENC_EMBD_PATCH,
         MODEL_TENSOR.V_ENC_EMBD_POS,
         MODEL_TENSOR.V_ENC_INPUT_NORM,
+        MODEL_TENSOR.V_ENC_ATTN_QKV,
         MODEL_TENSOR.V_ENC_ATTN_Q,
         MODEL_TENSOR.V_ENC_ATTN_Q_NORM,
         MODEL_TENSOR.V_ENC_ATTN_K,
@@ -1054,6 +1082,12 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_RESMPL_QUERY,
         MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
         MODEL_TENSOR.V_MM_PATCH_MERGER,
+        MODEL_TENSOR.V_MM_POST_FC_NORM,
+        MODEL_TENSOR.V_MM_UP,
+        MODEL_TENSOR.V_MM_DOWN,
+        MODEL_TENSOR.V_MM_GATE,
+        MODEL_TENSOR.V_TOK_BOI,
+        MODEL_TENSOR.V_TOK_EOI,
         # audio
         MODEL_TENSOR.A_ENC_EMBD_POS,
         MODEL_TENSOR.A_ENC_CONV1D,
@@ -2837,6 +2871,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_CHEXP,
         MODEL_TENSOR.FFN_UP_CHEXP,
     ],
+    MODEL_ARCH.COGVLM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.VISEXP_ATTN_QKV,
+        MODEL_TENSOR.VISEXP_ATTN_OUT,
+        MODEL_TENSOR.VISEXP_GATE,
+        MODEL_TENSOR.VISEXP_UP,
+        MODEL_TENSOR.VISEXP_DOWN,
+    ],
     # TODO
 }
 
@@ -3063,6 +3114,7 @@ class VisionProjectorType:
     LFM2 = "lfm2"
     KIMIVL = "kimivl"
     LIGHTONOCR = "lightonocr"
+    COGVLM = "cogvlm"
 
 
 # Items here are (block size, type size)
index d7dcd8efb84260871f5035dd384ce1d37a15b3ec..37e6fc85d516a51b7f90943961ba5cda91c7a039 100644 (file)
@@ -104,6 +104,7 @@ class TensorNameMap:
             "backbone.final_layer_norm",               # wavtokenizer
             "model.norm",                              # llama4
             "model.transformer.ln_f",                  # llada
+            "model.norm",                              # cogvlm
         ),
 
         # Rope frequencies
@@ -162,6 +163,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.layer_norm_1",             # jina-v2-code
             "rwkv.blocks.{bid}.ln2",                        # rwkv6
             "model.layers.{bid}.ln2",                       # rwkv7
+            "model.layers.{bid}.post_attention_layernorm",  # cogvlm
         ),
 
         # Attention query-key-value
@@ -184,6 +186,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.self_attention.query_key_value",                 # chatglm
             "transformer.layers.{bid}.attn.qkv_proj",                              # openelm
             "transformer_encoder.{bid}.qkv",                                       # neobert
+            "model.layers.{bid}.self_attn.language_expert_query_key_value",        # cogvlm
         ),
 
         # Attention query
@@ -279,6 +282,7 @@ class TensorNameMap:
             "model.transformer.blocks.{bid}.attn_out",                      # llada
             "layers.{bid}.self_attn.o_proj",                                # qwen3-embedding
             "backbone.layers.{bid}.mixer.o_proj",                           # nemotron-h
+            "model.layers.{bid}.self_attn.language_expert_dense",           # cogvlm
         ),
 
         # Attention output norm
@@ -418,6 +422,7 @@ class TensorNameMap:
             "model.transformer.blocks.{bid}.up_proj",                 # llada
             "layers.{bid}.mlp.up_proj",                               # qwen3-embedding
             "backbone.layers.{bid}.mixer.up_proj",                    # nemotron-h
+            "model.layers.{bid}.mlp.language_mlp.up_proj",            # cogvlm
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -450,21 +455,22 @@ class TensorNameMap:
 
         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
-            "model.layers.{bid}.mlp.gate_proj",           # llama-hf refact olmo2
-            "layers.{bid}.mlp.gate_proj",                 # embeddinggemma
-            "layers.{bid}.feed_forward.w1",               # llama-pth
-            "transformer.h.{bid}.mlp.w2",                 # qwen
-            "transformer.h.{bid}.mlp.c_fc2",              # jais
-            "model.layers.layers.{bid}.mlp.gate_proj",    # plamo
-            "model.layers.{bid}.feed_forward.w1",         # internlm2
-            "encoder.layers.{bid}.mlp.fc12",              # nomic-bert
-            "encoder.layer.{bid}.mlp.gated_layers_w",     # jina-bert-v2 (split up/gate, no longer used)
-            "transformer.h.{bid}.mlp.linear_1",           # refact
-            "model.layers.{bid}.residual_mlp.w1",         # arctic
-            "transformer.h.{bid}.mlp.c_fc_0",             # exaone
-            "model.layers.{bid}.feed_forward.gate_proj",  # llama4 jamba granite-hybrid
-            "model.transformer.blocks.{bid}.ff_proj",     # llada
-            "layers.{bid}.mlp.gate_proj",                 # qwen3-embedding
+            "model.layers.{bid}.mlp.gate_proj",               # llama-hf refact olmo2
+            "layers.{bid}.mlp.gate_proj",                     # embeddinggemma
+            "layers.{bid}.feed_forward.w1",                   # llama-pth
+            "transformer.h.{bid}.mlp.w2",                     # qwen
+            "transformer.h.{bid}.mlp.c_fc2",                  # jais
+            "model.layers.layers.{bid}.mlp.gate_proj",        # plamo
+            "model.layers.{bid}.feed_forward.w1",             # internlm2
+            "encoder.layers.{bid}.mlp.fc12",                  # nomic-bert
+            "encoder.layer.{bid}.mlp.gated_layers_w",         # jina-bert-v2 (split up/gate, no longer used)
+            "transformer.h.{bid}.mlp.linear_1",               # refact
+            "model.layers.{bid}.residual_mlp.w1",             # arctic
+            "transformer.h.{bid}.mlp.c_fc_0",                 # exaone
+            "model.layers.{bid}.feed_forward.gate_proj",      # llama4 jamba granite-hybrid
+            "model.transformer.blocks.{bid}.ff_proj",         # llada
+            "layers.{bid}.mlp.gate_proj",                     # qwen3-embedding
+            "model.layers.{bid}.mlp.language_mlp.gate_proj",  # cogvlm
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -522,6 +528,7 @@ class TensorNameMap:
             "model.transformer.blocks.{bid}.ff_out",                  # llada
             "layers.{bid}.mlp.down_proj",                             # qwen3-embedding
             "backbone.layers.{bid}.mixer.down_proj",                  # nemotron-h
+            "model.layers.{bid}.mlp.language_mlp.down_proj",          # cogvlm
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -1047,6 +1054,26 @@ class TensorNameMap:
             "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5
         ),
 
+        MODEL_TENSOR.VISEXP_UP: (
+            "model.layers.{bid}.mlp.vision_mlp.up_proj",  # cogvlm
+        ),
+
+        MODEL_TENSOR.VISEXP_GATE: (
+            "model.layers.{bid}.mlp.vision_mlp.gate_proj",  # cogvlm
+        ),
+
+        MODEL_TENSOR.VISEXP_DOWN: (
+            "model.layers.{bid}.mlp.vision_mlp.down_proj",  # cogvlm
+        ),
+
+        MODEL_TENSOR.VISEXP_ATTN_OUT: (
+            "model.layers.{bid}.self_attn.vision_expert_dense",  # cogvlm
+        ),
+
+        MODEL_TENSOR.VISEXP_ATTN_QKV: (
+            "model.layers.{bid}.self_attn.vision_expert_query_key_value",  # cogvlm
+        ),
+
         ############################################################################
         # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
         MODEL_TENSOR.ENC_OUTPUT_NORM: (
@@ -1148,6 +1175,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MMPROJ_FC: (
             "model.connector.modality_projection.proj", # SmolVLM
+            "model.vision.linear_proj.linear_proj", # cogvlm
         ),
 
         MODEL_TENSOR.V_MMPROJ_MLP: (
@@ -1164,6 +1192,7 @@ class TensorNameMap:
             "vision_tower.vision_model.embeddings.class_embedding",
             "model.vision_tower.embeddings.cls_token", # Intern-S1
             "vision_model.class_embedding", # llama 4
+            "model.vision.patch_embedding.cls_embedding", # cogvlm
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_PATCH: (
@@ -1176,6 +1205,7 @@ class TensorNameMap:
             "vision_model.patch_embedding.linear", # llama 4
             "visual.patch_embed.proj", # qwen2vl
             "vision_tower.patch_embed.proj", # kimi-vl
+            "model.vision.patch_embedding.proj", # cogvlm
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_POS: (
@@ -1185,6 +1215,11 @@ class TensorNameMap:
             "model.vision_model.embeddings.position_embedding", # SmolVLM
             "vision_model.positional_embedding_vlm", # llama 4
             "vision_tower.patch_embed.pos_emb", # kimi-vl
+            "model.vision.patch_embedding.position_embedding", # cogvlm
+        ),
+
+        MODEL_TENSOR.V_ENC_ATTN_QKV: (
+            "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_Q: (
@@ -1244,6 +1279,7 @@ class TensorNameMap:
             "vision_model.model.layers.{bid}.input_layernorm", # llama4
             "visual.blocks.{bid}.norm1", # qwen2vl
             "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
+            "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_O: (
@@ -1257,6 +1293,7 @@ class TensorNameMap:
             "vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
             "visual.blocks.{bid}.attn.proj", # qwen2vl
             "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
+            "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
         ),
 
         MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@@ -1270,6 +1307,7 @@ class TensorNameMap:
             "vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
             "visual.blocks.{bid}.norm2", # qwen2vl
             "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
+            "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
         ),
 
         MODEL_TENSOR.V_ENC_FFN_UP: (
@@ -1283,6 +1321,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.mlp.fc1", # qwen2vl
             "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
             "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
+            "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
         ),
 
         MODEL_TENSOR.V_ENC_FFN_GATE: (
@@ -1302,6 +1341,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.mlp.fc2", # qwen2vl
             "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
             "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
+            "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
         ),
 
         MODEL_TENSOR.V_LAYER_SCALE_1: (
@@ -1338,6 +1378,7 @@ class TensorNameMap:
             "multi_modal_projector.layer_norm",
             "multi_modal_projector.pre_norm",
             "pre_mm_projector_norm",
+            "model.vision.linear_proj.norm1", # cogvlm
         ),
 
         MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
@@ -1397,6 +1438,30 @@ class TensorNameMap:
             "patch_merger.merging_layer", # mistral
         ),
 
+        MODEL_TENSOR.V_MM_POST_FC_NORM: (
+            "model.vision.linear_proj.norm1", # cogvlm
+        ),
+
+        MODEL_TENSOR.V_MM_UP: (
+            "model.vision.linear_proj.dense_h_to_4h", # cogvlm
+        ),
+
+        MODEL_TENSOR.V_MM_DOWN: (
+            "model.vision.linear_proj.dense_4h_to_h", # cogvlm
+        ),
+
+        MODEL_TENSOR.V_MM_GATE: (
+            "model.vision.linear_proj.gate_proj", # cogvlm
+        ),
+
+        MODEL_TENSOR.V_TOK_BOI: (
+            "model.vision.boi", # cogvlm
+        ),
+
+        MODEL_TENSOR.V_TOK_EOI: (
+            "model.vision.eoi", # cogvlm
+        ),
+
         # audio (mtmd)
 
         MODEL_TENSOR.A_ENC_EMBD_POS: (
index 8ca769c5fd2ef7e0d3be88e34a33fc11d0e70c83..ba45e88c08e3ad359ea11cb9784247fc5029a32f 100644 (file)
@@ -103,6 +103,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_SEED_OSS,         "seed_oss"         },
     { LLM_ARCH_GROVEMOE,         "grovemoe"         },
     { LLM_ARCH_APERTUS,          "apertus"          },
+    { LLM_ARCH_COGVLM,           "cogvlm"           },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -2312,6 +2313,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_CHEXPS,      "blk.%d.ffn_up_chexps" },
         },
     },
+    {
+        LLM_ARCH_COGVLM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_VISEXP_ATTN_QKV, "blk.%d.vis_attn_qkv" },
+            { LLM_TENSOR_VISEXP_ATTN_OUT, "blk.%d.vis_attn_output" },
+            { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" },
+            { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" },
+            { LLM_TENSOR_VISEXP_FFN_UP,   "blk.%d.vis_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2488,6 +2509,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SHORTCONV_CONV,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
     {LLM_TENSOR_SHORTCONV_INPROJ,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_SHORTCONV_OUTPROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_VISEXP_ATTN_QKV,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_VISEXP_ATTN_OUT,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_VISEXP_FFN_GATE,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_VISEXP_FFN_DOWN,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_VISEXP_FFN_UP,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     // NextN/MTP tensors are currently ignored (reserved for future MTP support)
     // These tensors only exist in the last layer(s) and are treated as output tensors
     {LLM_TENSOR_NEXTN_EH_PROJ,              {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
index dea725c1a753a92736fd61bcc0228fa9e3f19bac..3350e8b43153ca51f47278390bd4ae1b86c0b248 100644 (file)
@@ -107,6 +107,7 @@ enum llm_arch {
     LLM_ARCH_SEED_OSS,
     LLM_ARCH_GROVEMOE,
     LLM_ARCH_APERTUS,
+    LLM_ARCH_COGVLM,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -455,6 +456,11 @@ enum llm_tensor {
     LLM_TENSOR_SHORTCONV_CONV,
     LLM_TENSOR_SHORTCONV_INPROJ,
     LLM_TENSOR_SHORTCONV_OUTPROJ,
+    LLM_TENSOR_VISEXP_ATTN_QKV,
+    LLM_TENSOR_VISEXP_ATTN_OUT,
+    LLM_TENSOR_VISEXP_FFN_GATE,
+    LLM_TENSOR_VISEXP_FFN_DOWN,
+    LLM_TENSOR_VISEXP_FFN_UP,
     LLM_TENSOR_NEXTN_EH_PROJ,
     LLM_TENSOR_NEXTN_EMBED_TOKENS,
     LLM_TENSOR_NEXTN_ENORM,
index ea6f59ed482bb24f5a9ccb0d2435164f3b382441..35759a00aecadc98907f6a843f4d6e105a145780 100644 (file)
@@ -2124,6 +2124,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_COGVLM:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 32: type = LLM_TYPE_13B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -6136,6 +6144,41 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), { n_embd_head_k }, TENSOR_NOT_REQUIRED);
                     }
                 } break;
+            case LLM_ARCH_COGVLM:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.visexp_attn_wqkv = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0);
+                        layer.visexp_attn_wo = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                        layer.visexp_ffn_gate = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.visexp_ffn_down = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.visexp_ffn_up   = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -19641,6 +19684,104 @@ struct llm_build_apertus : public llm_graph_context {
     }
 };
 
+struct llm_build_cogvlm : public llm_graph_context {
+    llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        float kq_scale = 1.0f / sqrtf(float(n_embd_head));
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * inpL, * cur;
+        inpL = build_inp_embd(model.tok_embd);
+
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv();
+
+        // check ubatch to see if we have input tokens (text)
+        // or an input embedding vector (image)
+        bool is_text;
+        if (ubatch.token) {
+            is_text = true;
+        } else {
+            is_text = false;
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+            // get either the text or image weight tensors
+            ggml_tensor * wqkv, * wo;
+            ggml_tensor * ffn_gate, * ffn_down, * ffn_up;
+
+            if (is_text) {
+                wqkv = model.layers[il].wqkv;
+                wo = model.layers[il].wo;
+                ffn_gate = model.layers[il].ffn_gate;
+                ffn_down = model.layers[il].ffn_down;
+                ffn_up = model.layers[il].ffn_up;
+            } else {
+                wqkv = model.layers[il].visexp_attn_wqkv;
+                wo = model.layers[il].visexp_attn_wo;
+                ffn_gate = model.layers[il].visexp_ffn_gate;
+                ffn_down = model.layers[il].visexp_ffn_down;
+                ffn_up = model.layers[il].visexp_ffn_up;
+            }
+
+            ggml_tensor * inpSA = inpL;
+            cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+
+            // build self attention
+            {
+                ggml_tensor * qkv = build_lora_mm(wqkv, cur);
+
+                // split qkv into Q, K, V along the first dimension
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float),
+                    qkv->nb[1], 0);
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
+                    qkv->nb[1], n_embd * ggml_element_size(qkv));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
+                    qkv->nb[1], 2 * n_embd * ggml_element_size(qkv));
+
+                Qcur = ggml_rope(ctx0, Qcur, inp_pos, n_embd_head, rope_type);
+                Kcur = ggml_rope(ctx0, Kcur, inp_pos, n_embd_head, rope_type);
+
+                cur = build_attn(inp_attn, wo, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    ffn_up,   NULL, NULL,
+                    ffn_gate, NULL, NULL,
+                    ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+        ggml_build_forward_expand(gf, cur);
+
+    }
+};
+
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const {
     llama_memory_i * res;
 
@@ -20165,6 +20306,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_apertus>(*this, params);
             } break;
+        case LLM_ARCH_COGVLM:
+            {
+                llm = std::make_unique<llm_build_cogvlm>(*this, params);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -20382,6 +20527,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_SEED_OSS:
         case LLM_ARCH_GROVEMOE:
         case LLM_ARCH_APERTUS:
+        case LLM_ARCH_COGVLM:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
index 1ab1cf7f8e94d69e926bf6cf8e88d0379db5eb4c..a5affda1c999d8e3aaad4e09df1be58785d4bdd0 100644 (file)
@@ -384,6 +384,13 @@ struct llama_layer {
     // openai-moe
     struct ggml_tensor * attn_sinks = nullptr;
 
+    // cogvlm
+    struct ggml_tensor * visexp_attn_wqkv = nullptr;
+    struct ggml_tensor * visexp_attn_wo   = nullptr;
+    struct ggml_tensor * visexp_ffn_gate  = nullptr;
+    struct ggml_tensor * visexp_ffn_down  = nullptr;
+    struct ggml_tensor * visexp_ffn_up    = nullptr;
+
     // xIELU activation parameters for Apertus
     struct ggml_tensor * ffn_act_alpha_n = nullptr;
     struct ggml_tensor * ffn_act_alpha_p = nullptr;
index ad2108d1798ae2e31c8efad146c905b512b6ed50..d4b88cb6980da6a12207f557457da752fa6cc192 100644 (file)
@@ -63,6 +63,7 @@
 #define TN_PATCH_EMBD      "v.patch_embd.weight"  // not rename tensor with ".0" postfix for backwrad compat
 #define TN_PATCH_EMBD_1    "v.patch_embd.weight.1"
 #define TN_PATCH_BIAS      "v.patch_embd.bias"
+#define TN_ATTN_QKV        "%s.blk.%d.attn_qkv.%s"
 #define TN_ATTN_K          "%s.blk.%d.attn_k.%s"
 #define TN_ATTN_Q          "%s.blk.%d.attn_q.%s"
 #define TN_ATTN_V          "%s.blk.%d.attn_v.%s"
 #define TN_MM_NORM_PRE  "mm.a.norm_pre.%s"
 #define TN_MM_NORM_MID  "mm.a.norm_mid.%s"
 
+// cogvlm
+#define TN_MM_POST_FC_NORM "mm.post_fc_norm.%s"
+#define TN_MM_H_TO_4H      "mm.up.%s"
+#define TN_MM_GATE         "mm.gate.%s"
+#define TN_MM_4H_TO_H      "mm.down.%s"
+#define TN_TOK_BOI         "v.boi"
+#define TN_TOK_EOI         "v.eoi"
+
 // align x to upper multiple of n
 #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
 
@@ -141,6 +150,7 @@ enum projector_type {
     PROJECTOR_TYPE_KIMIVL,
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_UNKNOWN,
+    PROJECTOR_TYPE_COGVLM,
 };
 
 static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
@@ -163,6 +173,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_LFM2,      "lfm2"},
     { PROJECTOR_TYPE_KIMIVL,    "kimivl"},
     { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
+    { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
index b44f0a3a28ad2842b437e809a48c167909af7eb4..a135ba0a2b35312de2308d818e8fb5b856ae63db 100644 (file)
@@ -214,6 +214,8 @@ struct clip_layer {
     ggml_tensor * q_b = nullptr;
     ggml_tensor * v_w = nullptr;
     ggml_tensor * v_b = nullptr;
+    ggml_tensor * qkv_w = nullptr;
+    ggml_tensor * qkv_b = nullptr;
 
     ggml_tensor * o_w = nullptr;
     ggml_tensor * o_b = nullptr;
@@ -286,8 +288,6 @@ struct clip_model {
     // GLMV-Edge projection
     ggml_tensor * mm_model_adapter_conv_w = nullptr;
     ggml_tensor * mm_model_adapter_conv_b = nullptr;
-    ggml_tensor * mm_glm_tok_boi = nullptr;
-    ggml_tensor * mm_glm_tok_eoi = nullptr;
 
     // MobileVLM projection
     ggml_tensor * mm_model_mlp_1_w = nullptr;
@@ -359,6 +359,15 @@ struct clip_model {
     ggml_tensor * mm_norm_pre_w = nullptr;
     ggml_tensor * mm_norm_mid_w = nullptr;
 
+    // cogvlm
+    ggml_tensor * mm_post_fc_norm_w = nullptr;
+    ggml_tensor * mm_post_fc_norm_b = nullptr;
+    ggml_tensor * mm_h_to_4h_w = nullptr;
+    ggml_tensor * mm_gate_w = nullptr;
+    ggml_tensor * mm_4h_to_h_w = nullptr;
+    ggml_tensor * mm_boi = nullptr;
+    ggml_tensor * mm_eoi = nullptr;
+
     bool audio_has_avgpool() const {
         return proj_type == PROJECTOR_TYPE_QWEN2A
             || proj_type == PROJECTOR_TYPE_VOXTRAL;
@@ -1494,8 +1503,8 @@ struct clip_graph {
             // note: these embeddings are not present in text model, hence we cannot process them as text tokens
             // see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53
             {
-                embeddings = ggml_concat(ctx0, model.mm_glm_tok_boi, embeddings, 1); // BOI
-                embeddings = ggml_concat(ctx0, embeddings, model.mm_glm_tok_eoi, 1); // EOI
+                embeddings = ggml_concat(ctx0, model.mm_boi, embeddings, 1); // BOI
+                embeddings = ggml_concat(ctx0, embeddings, model.mm_eoi, 1); // EOI
             }
         }
 
@@ -1613,6 +1622,104 @@ struct clip_graph {
         return gf;
     }
 
+    // cogvlm vision encoder
+    ggml_cgraph * build_cogvlm() {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1; // +1 for [CLS]
+
+        // build input and concatenate class embedding
+        ggml_tensor * inp = build_inp();
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+        inp = ggml_add(ctx0, inp, model.position_embeddings);
+        cb(inp, "inp_pos", -1);
+
+        ggml_tensor * inpL = inp;
+
+        for (int il = 0; il < n_layer; il++) {
+            auto & layer = model.layers[il];
+            ggml_tensor * cur = inpL;
+
+            cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+
+            cur = ggml_add(ctx0, cur, layer.qkv_b);
+
+            ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                cur->nb[1], 0);
+            ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                cur->nb[1], n_embd * sizeof(float));
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                cur->nb[1], 2 * n_embd * sizeof(float));
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(layer.o_w, layer.o_b,
+                Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+            cb(cur, "attn_out", il);
+
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "attn_post_norm", il);
+
+            cur = ggml_add(ctx0, cur, inpL);
+            inpL = cur;
+
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                hparams.ffn_op, il);
+
+            cb(cur, "ffn_out", il);
+
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "ffn_post_norm", il);
+
+            cur = ggml_add(ctx0, cur, inpL);
+            cb(cur, "layer_out", il);
+            inpL = cur;
+
+        }
+
+        // remove CLS token (like build_llama4 does)
+        ggml_tensor * cur = ggml_view_2d(ctx0, inpL,
+            n_embd, n_patches,
+            ggml_row_size(inpL->type, n_embd), 0);
+
+        // Multiply with mm_model_proj
+        cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
+
+        // Apply layernorm, weight, bias
+        cur = build_norm(cur, model.mm_post_fc_norm_w, model.mm_post_fc_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
+
+        // Apply GELU
+        cur = ggml_gelu_inplace(ctx0, cur);
+
+        // Branch 1: multiply with mm_h_to_4h_w
+        ggml_tensor * h_to_4h = ggml_mul_mat(ctx0, model.mm_h_to_4h_w, cur);
+
+        // Branch 2: multiply with mm_gate_w
+        ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_gate_w, cur);
+
+        // Apply silu
+        gate = ggml_swiglu_split(ctx0, gate, h_to_4h);
+
+        // Apply mm_4h_to_h_w
+        cur = ggml_mul_mat(ctx0, model.mm_4h_to_h_w, gate);
+
+        // Concatenate with boi and eoi
+        cur = ggml_concat(ctx0, model.mm_boi, cur, 1);
+        cur = ggml_concat(ctx0, cur, model.mm_eoi, 1);
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
 private:
     //
     // utility functions
@@ -2126,6 +2233,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = graph.build_kimivl();
             } break;
+        case PROJECTOR_TYPE_COGVLM:
+            {
+                res = graph.build_cogvlm();
+            } break;
         default:
             {
                 res = graph.build_llava();
@@ -2532,10 +2643,11 @@ struct clip_model_loader {
         model.layers.resize(hparams.n_layer);
         for (int il = 0; il < hparams.n_layer; ++il) {
             auto & layer = model.layers[il];
-            layer.k_w    = get_tensor(string_format(TN_ATTN_K,      prefix, il, "weight"));
-            layer.q_w    = get_tensor(string_format(TN_ATTN_Q,      prefix, il, "weight"));
-            layer.v_w    = get_tensor(string_format(TN_ATTN_V,      prefix, il, "weight"));
+            layer.k_w    = get_tensor(string_format(TN_ATTN_K,      prefix, il, "weight"), false);
+            layer.q_w    = get_tensor(string_format(TN_ATTN_Q,      prefix, il, "weight"), false);
+            layer.v_w    = get_tensor(string_format(TN_ATTN_V,      prefix, il, "weight"), false);
             layer.o_w    = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
+            layer.qkv_w  = get_tensor(string_format(TN_ATTN_QKV,    prefix, il, "weight"), false);
             layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
             layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
             layer.ln_1_w = get_tensor(string_format(TN_LN_1,        prefix, il, "weight"), false);
@@ -2547,6 +2659,7 @@ struct clip_model_loader {
             layer.q_b    = get_tensor(string_format(TN_ATTN_Q,      prefix, il, "bias"), false);
             layer.v_b    = get_tensor(string_format(TN_ATTN_V,      prefix, il, "bias"), false);
             layer.o_b    = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
+            layer.qkv_b  = get_tensor(string_format(TN_ATTN_QKV,    prefix, il, "bias"), false);
             layer.ln_1_b = get_tensor(string_format(TN_LN_1,        prefix, il, "bias"), false);
             layer.ln_2_b = get_tensor(string_format(TN_LN_2,        prefix, il, "bias"), false);
 
@@ -2682,8 +2795,8 @@ struct clip_model_loader {
                     model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight"));
                     model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight"));
                     model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight"));
-                    model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
-                    model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
+                    model.mm_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
+                    model.mm_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
                 } break;
             case PROJECTOR_TYPE_QWEN2VL:
             case PROJECTOR_TYPE_QWEN25VL:
@@ -2777,6 +2890,17 @@ struct clip_model_loader {
                     model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
                     model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
                 } break;
+            case PROJECTOR_TYPE_COGVLM:
+                {
+                    model.mm_model_proj     = get_tensor(TN_MM_PROJECTOR);
+                    model.mm_post_fc_norm_w = get_tensor(string_format(TN_MM_POST_FC_NORM, "weight"));
+                    model.mm_post_fc_norm_b = get_tensor(string_format(TN_MM_POST_FC_NORM, "bias"));
+                    model.mm_h_to_4h_w      = get_tensor(string_format(TN_MM_H_TO_4H,      "weight"));
+                    model.mm_gate_w         = get_tensor(string_format(TN_MM_GATE,         "weight"));
+                    model.mm_4h_to_h_w      = get_tensor(string_format(TN_MM_4H_TO_H,      "weight"));
+                    model.mm_boi            = get_tensor(TN_TOK_BOI);
+                    model.mm_eoi            = get_tensor(TN_TOK_EOI);
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -3825,7 +3949,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_GLM_EDGE:
             {
                 n_patches /= 4;
-                if (ctx->model.mm_glm_tok_boi) {
+                if (ctx->model.mm_boi) {
                     n_patches += 2; // for BOI and EOI token embeddings
                 }
             } break;
@@ -3915,6 +4039,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                     n_patches /= 2;
                 }
             } break;
+        case PROJECTOR_TYPE_COGVLM:
+            {
+                n_patches += 2; // for BOI and EOI token embeddings
+            } break;
         default:
             GGML_ABORT("unsupported projector type");
     }
@@ -4323,6 +4451,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_VOXTRAL:
+        case PROJECTOR_TYPE_COGVLM:
             {
                 // do nothing
             } break;
@@ -4427,6 +4556,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
             return ctx->model.mm_2_w->ne[1];
+        case PROJECTOR_TYPE_COGVLM:
+            return ctx->model.mm_4h_to_h_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }