]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd: add clip_graph::build_mm() (#20751)
authorXuan-Son Nguyen <redacted>
Thu, 19 Mar 2026 12:11:39 +0000 (13:11 +0100)
committerGitHub <redacted>
Thu, 19 Mar 2026 12:11:39 +0000 (13:11 +0100)
* clip: add build_mm()

* apply to all models

* add TODO for bias overload

15 files changed:
tools/mtmd/clip-graph.h
tools/mtmd/clip.cpp
tools/mtmd/models/cogvlm.cpp
tools/mtmd/models/conformer.cpp
tools/mtmd/models/glm4v.cpp
tools/mtmd/models/llama4.cpp
tools/mtmd/models/llava.cpp
tools/mtmd/models/minicpmv.cpp
tools/mtmd/models/mobilenetv5.cpp
tools/mtmd/models/pixtral.cpp
tools/mtmd/models/qwen2vl.cpp
tools/mtmd/models/qwen3vl.cpp
tools/mtmd/models/siglip.cpp
tools/mtmd/models/whisper-enc.cpp
tools/mtmd/models/youtuvl.cpp

index 4c7f7504cfcd054be3d1e69f7d7c83f6c3a99a6d..3604bf77e8dc0b1b616a7396f93e8da58ea7bc5c 100644 (file)
@@ -41,6 +41,11 @@ struct clip_graph {
     virtual ~clip_graph() = default;
     virtual ggml_cgraph * build() = 0;
 
+    // wrapper around ggml_mul_mat, allow hooking (e.g. LoRA, clamping) depending on the model
+    // tensor w should be the weight matrix, and tensor x should be the input
+    virtual ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const;
+    // TODO: build_mm(w, b, x) to support bias
+
     //
     // utility functions
     //
index 3d6cf6fd8493bc839213224073dea1e0eb2e7360..44a19189ea91e63a811370e9f2244168520a8b83 100644 (file)
@@ -255,6 +255,10 @@ clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
     gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
 }
 
+ggml_tensor * clip_graph::build_mm(ggml_tensor * w, ggml_tensor * x) const {
+    return ggml_mul_mat(ctx0, w, x);
+}
+
 void clip_graph::cb(ggml_tensor * cur, const char * name, int il) const {
     if (il >= 0) {
         ggml_format_name(cur, "%s-%d", name, il);
@@ -326,7 +330,7 @@ ggml_tensor * clip_graph::build_vit(
             ggml_tensor * Vcur = nullptr;
             if (layer.qkv_w != nullptr) {
                 // fused qkv
-                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+                cur = build_mm(layer.qkv_w, cur);
                 if (layer.qkv_b != nullptr) {
                     cur = ggml_add(ctx0, cur, layer.qkv_b);
                 }
@@ -360,17 +364,17 @@ ggml_tensor * clip_graph::build_vit(
 
             } else {
                 // separate q, k, v
-                Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+                Qcur = build_mm(layer.q_w, cur);
                 if (layer.q_b) {
                     Qcur = ggml_add(ctx0, Qcur, layer.q_b);
                 }
 
-                Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+                Kcur = build_mm(layer.k_w, cur);
                 if (layer.k_b) {
                     Kcur = ggml_add(ctx0, Kcur, layer.k_b);
                 }
 
-                Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+                Vcur = build_mm(layer.v_w, cur);
                 if (layer.v_b) {
                     Vcur = ggml_add(ctx0, Vcur, layer.v_b);
                 }
@@ -517,7 +521,7 @@ ggml_tensor * clip_graph::build_ffn(
         ffn_op_type type_op,
         int il) const {
 
-    ggml_tensor * tmp = up ? ggml_mul_mat(ctx0, up, cur) : cur;
+    ggml_tensor * tmp = up ? build_mm(up, cur) : cur;
     cb(tmp, "ffn_up", il);
 
     if (up_b) {
@@ -526,7 +530,7 @@ ggml_tensor * clip_graph::build_ffn(
     }
 
     if (gate) {
-        cur = ggml_mul_mat(ctx0, gate, cur);
+        cur = build_mm(gate, cur);
         cb(cur, "ffn_gate", il);
 
         if (gate_b) {
@@ -580,7 +584,7 @@ ggml_tensor * clip_graph::build_ffn(
     }
 
     if (down) {
-        cur = ggml_mul_mat(ctx0, down, cur);
+        cur = build_mm(down, cur);
     }
 
     if (down_b) {
@@ -646,7 +650,7 @@ ggml_tensor * clip_graph::build_attn(
     cb(cur, "kqv_out", il);
 
     if (wo) {
-        cur = ggml_mul_mat(ctx0, wo, cur);
+        cur = build_mm(wo, cur);
     }
 
     if (wo_b) {
index d5b739c6873b41a46c95dcad07dc839f01536e29..44bc884421df49cfdcf3884030385c94eabb5b99 100644 (file)
@@ -19,7 +19,7 @@ ggml_cgraph * clip_graph_cogvlm::build() {
         auto & layer = model.layers[il];
         ggml_tensor * cur = inpL;
 
-        cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+        cur = build_mm(layer.qkv_w, cur);
 
         cur = ggml_add(ctx0, cur, layer.qkv_b);
 
@@ -67,7 +67,7 @@ ggml_cgraph * clip_graph_cogvlm::build() {
         ggml_row_size(inpL->type, n_embd), 0);
 
     // Multiply with mm_model_proj
-    cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
+    cur = build_mm(model.mm_model_proj, cur);
 
     // Apply layernorm, weight, bias
     cur = build_norm(cur, model.mm_post_fc_norm_w, model.mm_post_fc_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
@@ -76,16 +76,16 @@ ggml_cgraph * clip_graph_cogvlm::build() {
     cur = ggml_gelu_inplace(ctx0, cur);
 
     // Branch 1: multiply with mm_h_to_4h_w
-    ggml_tensor * h_to_4h = ggml_mul_mat(ctx0, model.mm_h_to_4h_w, cur);
+    ggml_tensor * h_to_4h = build_mm(model.mm_h_to_4h_w, cur);
 
     // Branch 2: multiply with mm_gate_w
-    ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_gate_w, cur);
+    ggml_tensor * gate = build_mm(model.mm_gate_w, cur);
 
     // Apply silu
     gate = ggml_swiglu_split(ctx0, gate, h_to_4h);
 
     // Apply mm_4h_to_h_w
-    cur = ggml_mul_mat(ctx0, model.mm_4h_to_h_w, gate);
+    cur = build_mm(model.mm_4h_to_h_w, gate);
 
     // Concatenate with boi and eoi
     cur = ggml_concat(ctx0, model.mm_boi, cur, 1);
index 9b1fab487397c41a2d3e89019d6033b08d02912b..f58c5048f594f5135efb9a540c6efaff71838c51 100644 (file)
@@ -56,7 +56,7 @@ ggml_cgraph * clip_graph_conformer::build() {
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
 
         // calculate out
-        cur = ggml_mul_mat(ctx0, model.pre_encode_out_w, cur);
+        cur = build_mm(model.pre_encode_out_w, cur);
         cur = ggml_add(ctx0, cur, model.pre_encode_out_b);
         cb(cur, "conformer.pre_encode.out", -1);
     }
@@ -87,7 +87,7 @@ ggml_cgraph * clip_graph_conformer::build() {
             cur = build_norm(residual, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, 1e-5, il);
             cb(cur, "conformer.layers.{}.norm_self_att", il);
 
-            ggml_tensor * Qcur     = ggml_mul_mat(ctx0, layer.q_w, cur);
+            ggml_tensor * Qcur     = build_mm(layer.q_w, cur);
             Qcur                   = ggml_add(ctx0, Qcur, layer.q_b);
             Qcur                   = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, Qcur->ne[1]);
             ggml_tensor * Q_bias_u = ggml_add(ctx0, Qcur, layer.pos_bias_u);
@@ -96,12 +96,12 @@ ggml_cgraph * clip_graph_conformer::build() {
             Q_bias_v               = ggml_permute(ctx0, Q_bias_v, 0, 2, 1, 3);
 
             // TODO @ngxson : some cont can/should be removed when ggml_mul_mat support these cases
-            ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+            ggml_tensor * Kcur = build_mm(layer.k_w, cur);
             Kcur               = ggml_add(ctx0, Kcur, layer.k_b);
             Kcur               = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, Kcur->ne[1]);
             Kcur               = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
 
-            ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+            ggml_tensor * Vcur = build_mm(layer.v_w, cur);
             Vcur               = ggml_add(ctx0, Vcur, layer.v_b);
             Vcur               = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, Vcur->ne[1]);
             Vcur               = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3));
@@ -111,7 +111,7 @@ ggml_cgraph * clip_graph_conformer::build() {
             matrix_ac               = ggml_cont(ctx0, ggml_permute(ctx0, matrix_ac, 1, 0, 2, 3));
             cb(matrix_ac, "conformer.layers.{}.self_attn.id3", il);
 
-            auto * p = ggml_mul_mat(ctx0, layer.linear_pos_w, pos_emb);
+            auto * p = build_mm(layer.linear_pos_w, pos_emb);
             cb(p, "conformer.layers.{}.self_attn.linear_pos", il);
             p = ggml_reshape_3d(ctx0, p, d_head, n_head, p->ne[1]);
             p = ggml_permute(ctx0, p, 0, 2, 1, 3);
@@ -143,7 +143,7 @@ ggml_cgraph * clip_graph_conformer::build() {
             x                  = ggml_permute(ctx0, x, 2, 0, 1, 3);
             x                  = ggml_cont_2d(ctx0, x, x->ne[0] * x->ne[1], x->ne[2]);
 
-            ggml_tensor * out = ggml_mul_mat(ctx0, layer.o_w, x);
+            ggml_tensor * out = build_mm(layer.o_w, x);
             out               = ggml_add(ctx0, out, layer.o_b);
             cb(out, "conformer.layers.{}.self_attn.linear_out", il);
 
@@ -157,7 +157,7 @@ ggml_cgraph * clip_graph_conformer::build() {
         // conv
         {
             auto * x = cur;
-            x = ggml_mul_mat(ctx0, layer.conv_pw1_w, x);
+            x = build_mm(layer.conv_pw1_w, x);
             x = ggml_add(ctx0, x, layer.conv_pw1_b);
             cb(x, "conformer.layers.{}.conv.pointwise_conv1", il);
 
@@ -181,7 +181,7 @@ ggml_cgraph * clip_graph_conformer::build() {
             x = ggml_silu(ctx0, x);
 
             // pointwise_conv2
-            x = ggml_mul_mat(ctx0, layer.conv_pw2_w, x);
+            x = build_mm(layer.conv_pw2_w, x);
             x = ggml_add(ctx0, x, layer.conv_pw2_b);
 
             cur = x;
index 6f52df41ab0197d143f188f2e0eb07bed1f5273c..9dbb162c5912aafd2f3b5e9bc9b29f415e167709 100644 (file)
@@ -97,7 +97,7 @@ ggml_cgraph * clip_graph_glm4v::build() {
 
     // FC projector
     {
-        cur = ggml_mul_mat(ctx0, model.projection, cur);
+        cur = build_mm(model.projection, cur);
         // default LayerNorm (post_projection_norm)
         cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
         cur = ggml_gelu_erf(ctx0, cur);
index 30d1df5bcdd65f1211d9f942e62dbc5885780d93..01af54bbab7e0d8d4b002e0234467fa35104d0cc 100644 (file)
@@ -22,7 +22,7 @@ ggml_cgraph * clip_graph_llama4::build() {
         ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
                                                 patch_size, patch_size, 3, n_embd);
         inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
-        inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+        inp = build_mm(model.patch_embeddings_0, inp);
         inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
         cb(inp, "patch_conv", -1);
     }
@@ -78,15 +78,15 @@ ggml_cgraph * clip_graph_llama4::build() {
 
     // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
     {
-        cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
+        cur = build_mm(model.mm_model_mlp_1_w, cur);
         cur = ggml_gelu(ctx0, cur);
-        cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
+        cur = build_mm(model.mm_model_mlp_2_w, cur);
         cur = ggml_gelu(ctx0, cur);
         cb(cur, "adapter_mlp", -1);
     }
 
     // Llama4MultiModalProjector
-    cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
+    cur = build_mm(model.mm_model_proj, cur);
     cb(cur, "projected", -1);
 
     // build the graph
index 0bfb5f05f66d7e83dd2b6df2d6b225eb08ea4c70..4af17ccfe853cb858fbbcd44ec206376bcb78d20 100644 (file)
@@ -70,17 +70,17 @@ ggml_cgraph * clip_graph_llava::build() {
 
         // self-attention
         {
-            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+            ggml_tensor * Qcur = build_mm(layer.q_w, cur);
             if (layer.q_b) {
                 Qcur = ggml_add(ctx0, Qcur, layer.q_b);
             }
 
-            ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+            ggml_tensor * Kcur = build_mm(layer.k_w, cur);
             if (layer.k_b) {
                 Kcur = ggml_add(ctx0, Kcur, layer.k_b);
             }
 
-            ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+            ggml_tensor * Vcur = build_mm(layer.v_w, cur);
             if (layer.v_b) {
                 Vcur = ggml_add(ctx0, Vcur, layer.v_b);
             }
@@ -164,17 +164,17 @@ ggml_cgraph * clip_graph_llava::build() {
 
         // llava projector
         if (proj_type == PROJECTOR_TYPE_MLP) {
-            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+            embeddings = build_mm(model.mm_0_w, embeddings);
             embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
 
             embeddings = ggml_gelu(ctx0, embeddings);
             if (model.mm_2_w) {
-                embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+                embeddings = build_mm(model.mm_2_w, embeddings);
                 embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
             }
         }
         else if (proj_type == PROJECTOR_TYPE_MLP_NORM) {
-            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+            embeddings = build_mm(model.mm_0_w, embeddings);
             embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
             // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
             // First LayerNorm
@@ -186,7 +186,7 @@ ggml_cgraph * clip_graph_llava::build() {
             embeddings = ggml_gelu(ctx0, embeddings);
 
             // Second linear layer
-            embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
+            embeddings = build_mm(model.mm_3_w, embeddings);
             embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
 
             // Second LayerNorm
@@ -197,10 +197,10 @@ ggml_cgraph * clip_graph_llava::build() {
         else if (proj_type == PROJECTOR_TYPE_LDP) {
             // MobileVLM projector
             int n_patch = 24;
-            ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
+            ggml_tensor * mlp_1 = build_mm(model.mm_model_mlp_1_w, embeddings);
             mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
             mlp_1 = ggml_gelu(ctx0, mlp_1);
-            ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
+            ggml_tensor * mlp_3 = build_mm(model.mm_model_mlp_3_w, mlp_1);
             mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
             // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
 
@@ -229,10 +229,10 @@ ggml_cgraph * clip_graph_llava::build() {
                 // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
                 // pointwise conv
                 block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
+                block_1 = build_mm(model.mm_model_block_1_block_1_fc1_w, block_1);
                 block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
                 block_1 = ggml_relu(ctx0, block_1);
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
+                block_1 = build_mm(model.mm_model_block_1_block_1_fc2_w, block_1);
                 block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
                 block_1 = ggml_hardsigmoid(ctx0, block_1);
                 // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
@@ -244,7 +244,7 @@ ggml_cgraph * clip_graph_llava::build() {
                 block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
 
                 // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
+                block_1 = build_mm(model.mm_model_block_1_block_2_0_w, block_1);
                 block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
 
                 // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
@@ -277,10 +277,10 @@ ggml_cgraph * clip_graph_llava::build() {
                 // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
                 // pointwise conv
                 block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
+                block_1 = build_mm(model.mm_model_block_2_block_1_fc1_w, block_1);
                 block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
                 block_1 = ggml_relu(ctx0, block_1);
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
+                block_1 = build_mm(model.mm_model_block_2_block_1_fc2_w, block_1);
                 block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
                 block_1 = ggml_hardsigmoid(ctx0, block_1);
 
@@ -292,7 +292,7 @@ ggml_cgraph * clip_graph_llava::build() {
                 block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
                 block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
                 // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
+                block_1 = build_mm(model.mm_model_block_2_block_2_0_w, block_1);
                 block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
 
 
@@ -307,10 +307,10 @@ ggml_cgraph * clip_graph_llava::build() {
         else if (proj_type == PROJECTOR_TYPE_LDPV2)
         {
             int n_patch = 24;
-            ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+            ggml_tensor * mlp_0 = build_mm(model.mm_model_mlp_0_w, embeddings);
             mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
             mlp_0 = ggml_gelu(ctx0, mlp_0);
-            ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
+            ggml_tensor * mlp_2 = build_mm(model.mm_model_mlp_2_w, mlp_0);
             mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
             // mlp_2 ne = [2048, 576, 1, 1]
             // // AVG Pool Layer 2*2, strides = 2
@@ -344,15 +344,15 @@ ggml_cgraph * clip_graph_llava::build() {
         embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
         // GLU
         {
-            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+            embeddings = build_mm(model.mm_model_mlp_0_w, embeddings);
             embeddings = ggml_norm(ctx0, embeddings, eps);
             embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
             embeddings = ggml_gelu_inplace(ctx0, embeddings);
             ggml_tensor * x = embeddings;
-            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
-            x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
+            embeddings = build_mm(model.mm_model_mlp_2_w, embeddings);
+            x = build_mm(model.mm_model_mlp_1_w,x);
             embeddings = ggml_swiglu_split(ctx0, embeddings, x);
-            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
+            embeddings = build_mm(model.mm_model_mlp_3_w, embeddings);
         }
         // arrangement of BOI/EOI token embeddings
         // note: these embeddings are not present in text model, hence we cannot process them as text tokens
index 3594ea29fa946dfb2c64e4541fb5416f85bc035a..924117ab2a1dd0a6b708db0e4a97392f38697e8e 100644 (file)
@@ -38,7 +38,7 @@ ggml_cgraph * clip_graph_minicpmv::build() {
     // resampler projector (it is just another transformer)
 
     ggml_tensor * q = model.mm_model_query;
-    ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
+    ggml_tensor * v = build_mm(model.mm_model_kv_proj, embeddings);
 
     // norm
     q = build_norm(q, model.mm_model_ln_q_w,  model.mm_model_ln_q_b,  NORM_TYPE_NORMAL, eps, -1);
@@ -77,13 +77,13 @@ ggml_cgraph * clip_graph_minicpmv::build() {
         // Use actual config value if available, otherwise fall back to hardcoded values
         int num_query = hparams.minicpmv_query_num;
         ggml_tensor * Q = ggml_add(ctx0,
-            ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q),
+            build_mm(model.mm_model_attn_q_w, q),
             model.mm_model_attn_q_b);
         ggml_tensor * K = ggml_add(ctx0,
-            ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k),
+            build_mm(model.mm_model_attn_k_w, k),
             model.mm_model_attn_k_b);
         ggml_tensor * V = ggml_add(ctx0,
-            ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v),
+            build_mm(model.mm_model_attn_v_w, v),
             model.mm_model_attn_v_b);
 
         Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query);
@@ -105,7 +105,7 @@ ggml_cgraph * clip_graph_minicpmv::build() {
     embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1);
 
     // projection
-    embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+    embeddings = build_mm(model.mm_model_proj, embeddings);
 
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
index 593afa1ddce8a0cf5ce2c50c86d10543e16f5f1a..1c42218d2affc3329a41fc498695de52ee7e4098 100644 (file)
@@ -429,7 +429,7 @@ ggml_cgraph * clip_graph_mobilenetv5::build() {
     // PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False)
     // Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size]
     if (model.mm_input_proj_w) {
-        cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur);
+        cur = build_mm(model.mm_input_proj_w, cur);
     }
 
     // 5. POST PROJECTION NORM
index a849210b53d5ad52d7bb2d69e57925122009e88c..d6d037b6941256c71227a449efedefd1049ccb8a 100644 (file)
@@ -43,7 +43,7 @@ ggml_cgraph * clip_graph_pixtral::build() {
 
         // project to n_embd
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
-        cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
+        cur = build_mm(model.mm_patch_merger_w, cur);
     }
 
     // LlavaMultiModalProjector (always using GELU activation)
index 85f158bb1c00b88c91d06df3c6b63279cdb24548..ebf10757376b0dcbaf5382091ac32d04d42d52bb 100644 (file)
@@ -90,11 +90,11 @@ ggml_cgraph * clip_graph_qwen2vl::build() {
         // self-attention
         {
             ggml_tensor * Qcur = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
+                build_mm(layer.q_w, cur), layer.q_b);
             ggml_tensor * Kcur = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
+                build_mm(layer.k_w, cur), layer.k_b);
             ggml_tensor * Vcur = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
+                build_mm(layer.v_w, cur), layer.v_b);
 
             Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
             Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
index 5ecb10fe4382732c1acb1a4421d57af0947bde50..fa1100dda8d5d08df6553008136f506659c57572 100644 (file)
@@ -85,7 +85,7 @@ ggml_cgraph * clip_graph_qwen3vl::build() {
 
         // self-attention
         {
-            cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+            cur = build_mm(layer.qkv_w, cur);
             cur = ggml_add(ctx0, cur, layer.qkv_b);
 
             ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
index 75f9b4db448855d7b341e9ea995d1102f9268231..9dafa35ea8162512f0d4aa9457381f9d1dbaacac 100644 (file)
@@ -43,7 +43,7 @@ ggml_cgraph * clip_graph_siglip::build() {
         // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
         const int scale_factor = model.hparams.n_merge;
         cur = build_patch_merge_permute(cur, scale_factor);
-        cur = ggml_mul_mat(ctx0, model.projection, cur);
+        cur = build_mm(model.projection, cur);
 
     } else if (proj_type == PROJECTOR_TYPE_LFM2) {
         // pixel unshuffle block
index 2f2b12775510ef592d50d49593816ac76eb3b8cc..ed61bb05bad2bdec5f52ec61beee8ef64e98ae36 100644 (file)
@@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_whisper_enc::build() {
         cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
 
         // ffn in
-        cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+        cur = build_mm(model.mm_1_w, cur);
 
         // swiglu
         // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
@@ -70,11 +70,11 @@ ggml_cgraph * clip_graph_whisper_enc::build() {
         cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
 
         // ffn out
-        cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+        cur = build_mm(model.mm_2_w, cur);
 
     } else if (proj_type == PROJECTOR_TYPE_QWEN2A) {
         // projector
-        cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
+        cur = build_mm(model.mm_fc_w, cur);
         cur = ggml_add(ctx0, cur, model.mm_fc_b);
 
     } else if (proj_type == PROJECTOR_TYPE_VOXTRAL) {
index ffbf2be5547758271c2953031ebfebc74a4b47f0..cd8f6d446f0895673255e4d21498e8610155154d 100644 (file)
@@ -43,7 +43,7 @@ ggml_cgraph * clip_graph_youtuvl::build() {
             ctx0, inp,
             3*patch_size* patch_size,  Hm * Wm * m * m, 1);
     }
-    inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+    inp = build_mm(model.patch_embeddings_0, inp);
 
     if (model.patch_bias) {
         inp = ggml_add(ctx0, inp, model.patch_bias);
@@ -97,11 +97,11 @@ ggml_cgraph * clip_graph_youtuvl::build() {
         // self-attention
         {
             ggml_tensor * Qcur = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
+                build_mm(layer.q_w, cur), layer.q_b);
             ggml_tensor * Kcur = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
+                build_mm(layer.k_w, cur), layer.k_b);
             ggml_tensor * Vcur = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
+                build_mm(layer.v_w, cur), layer.v_b);
 
             Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
             Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);