]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
clip : fix confused naming ffn_up and ffn_down (#13290)
authorXuan-Son Nguyen <redacted>
Mon, 5 May 2025 10:54:44 +0000 (12:54 +0200)
committerGitHub <redacted>
Mon, 5 May 2025 10:54:44 +0000 (12:54 +0200)
* clip :  fix confused naming ffn_up and ffn_down

* rm ffn_i/o/g naming

* rename n_embd, n_ff

* small fix

* no check n_ff

convert_hf_to_gguf.py
gguf-py/gguf/tensor_mapping.py
tools/llava/clip.cpp
tools/llava/mtmd-cli.cpp

index 34bed7a08a1b28383cc2231ece8c03660d4b8229..a47d7df6fd3a3d2062d9c5e2e4eaddebe4844ecb 100755 (executable)
@@ -1778,6 +1778,12 @@ class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # fix for SmolVLM2, missing `num_attention_heads` in config.json
+        if self.hf_arch == "VLlama3ForCausalLM":
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
+
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
index 2b089f84a841ad9d064db3f86b4c1ba4205e4187..003b0172c77b07a76d4f9de04e11abce16d6fc79 100644 (file)
@@ -977,15 +977,12 @@ class TensorNameMap:
             "visual.blocks.{bid}.norm2", # qwen2vl
         ),
 
-        # some namings are messed up because the original llava code swapped fc1 and fc2
-        # we have no better way to fix it, just be careful
-        # new models like pixtral use the correct naming
         MODEL_TENSOR.V_ENC_FFN_UP: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
             "vpm.encoder.layers.{bid}.mlp.fc1",
-            "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped)
+            "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
             "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
-            "visual.blocks.{bid}.mlp.fc2", # qwen2vl
+            "visual.blocks.{bid}.mlp.fc1", # qwen2vl
             "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
         ),
 
@@ -997,9 +994,9 @@ class TensorNameMap:
         MODEL_TENSOR.V_ENC_FFN_DOWN: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
             "vpm.encoder.layers.{bid}.mlp.fc2",
-            "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped)
+            "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
             "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
-            "visual.blocks.{bid}.mlp.fc1", # qwen2vl
+            "visual.blocks.{bid}.mlp.fc2", # qwen2vl
             "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
         ),
 
index cc03bf884a3cdce182ec01cb14073f668005c8ca..1414f6ba4f5ef62a15a5cfc4963ca5530c5a20ad 100644 (file)
@@ -155,8 +155,8 @@ enum patch_merge_type {
 struct clip_hparams {
     int32_t image_size;
     int32_t patch_size;
-    int32_t hidden_size;
-    int32_t n_intermediate;
+    int32_t n_embd;
+    int32_t n_ff;
     int32_t projection_dim;
     int32_t n_head;
     int32_t n_layer;
@@ -191,12 +191,6 @@ struct clip_layer {
     struct ggml_tensor * ln_1_w = nullptr;
     struct ggml_tensor * ln_1_b = nullptr;
 
-    // ff
-    struct ggml_tensor * ff_i_w = nullptr; // legacy naming
-    struct ggml_tensor * ff_i_b = nullptr; // legacy naming
-    struct ggml_tensor * ff_o_w = nullptr; // legacy naming
-    struct ggml_tensor * ff_o_b = nullptr; // legacy naming
-
     struct ggml_tensor * ff_up_w = nullptr;
     struct ggml_tensor * ff_up_b = nullptr;
     struct ggml_tensor * ff_gate_w = nullptr;
@@ -204,9 +198,6 @@ struct clip_layer {
     struct ggml_tensor * ff_down_w = nullptr;
     struct ggml_tensor * ff_down_b = nullptr;
 
-    struct ggml_tensor * ff_g_w = NULL;
-    struct ggml_tensor * ff_g_b = NULL;
-
     // layernorm 2
     struct ggml_tensor * ln_2_w = nullptr;
     struct ggml_tensor * ln_2_b = nullptr;
@@ -388,9 +379,9 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
 
     const int patch_size  = hparams.patch_size;
     const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
-    const int hidden_size = hparams.hidden_size;
+    const int n_embd      = hparams.n_embd;
     const int n_head      = hparams.n_head;
-    const int d_head      = hidden_size / n_head;
+    const int d_head      = n_embd / n_head;
     const int n_layer     = hparams.n_layer;
     const float eps       = hparams.eps;
 
@@ -411,7 +402,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
     ggml_set_input(inp_raw);
 
     struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+    inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd);
     inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
     inp = ggml_add(ctx0, inp, model.patch_bias);
 
@@ -456,7 +447,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
             KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+            cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches);
         }
 
         // attention output
@@ -473,14 +464,14 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
         }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);
 
         // siglip uses gelu
         cur = ggml_gelu(ctx0, cur);
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
 
         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);
@@ -504,11 +495,11 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
         const int kernel_size = patches_per_image / tokens_per_side;
 
         embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
-        embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
+        embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, n_embd, batch_size);
 
         // doing a pool2d to reduce the number of output tokens to 256
         embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
-        embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], n_embd, batch_size);
         embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
 
         // apply norm before projection
@@ -637,9 +628,9 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     const int n_patches_x = image_size_width  / patch_size;
     const int n_patches_y = image_size_height / patch_size;
     const int num_patches = n_patches_x * n_patches_y;
-    const int hidden_size = hparams.hidden_size;
+    const int n_embd      = hparams.n_embd;
     const int n_head      = hparams.n_head;
-    const int d_head      = hidden_size / n_head;
+    const int d_head      = n_embd / n_head;
     const int n_layer     = hparams.n_layer;
     const float eps       = hparams.eps;
     const int n_merge     = hparams.spatial_merge_size;
@@ -669,7 +660,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     ggml_set_input(pos_w);
 
     struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+    inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd);
     inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
 
     struct ggml_tensor * embeddings = inp;
@@ -710,7 +701,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
             KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+            cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches);
 
             cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
         }
@@ -753,8 +744,8 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
 
         // reshape image tokens to 2D grid
-        cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
-        cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
+        cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y);
+        cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd]
         cur = ggml_cont(ctx0, cur);
 
         // torch.nn.functional.unfold is just an im2col under the hood
@@ -762,7 +753,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
         cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
 
-        // project to hidden_size
+        // project to n_embd
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
         cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
         embeddings = cur;
@@ -785,9 +776,9 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     // arrangement of the [IMG_BREAK] token
     {
         // not efficient, but works
-        // the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows]
+        // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
         // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
-        // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
+        // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
 
         const int p_y             = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
         const int p_x             = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
@@ -827,9 +818,9 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
     const int patches_h            = image_size_height / patch_size;
     const int num_positions        = num_patches + (model.class_embedding ? 1 : 0);
     const int num_position_ids     = num_positions * 4; // m-rope requires 4 dim per position
-    const int hidden_size          = hparams.hidden_size;
+    const int n_embd               = hparams.n_embd;
     const int n_head               = hparams.n_head;
-    const int d_head               = hidden_size / n_head;
+    const int d_head               = n_embd / n_head;
     const int n_layer              = hparams.n_layer;
     const float eps                = hparams.eps;
 
@@ -864,14 +855,14 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
     inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3));  // [w, h, c, b] -> [c, w, h, b]
     inp = ggml_reshape_4d(
         ctx0, inp,
-        hidden_size * 2, patches_w / 2, patches_h, batch_size);
+        n_embd * 2, patches_w / 2, patches_h, batch_size);
     inp = ggml_reshape_4d(
         ctx0, inp,
-        hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
+        n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
     inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
     inp = ggml_reshape_3d(
         ctx0, inp,
-        hidden_size, patches_w * patches_h, batch_size);
+        n_embd, patches_w * patches_h, batch_size);
 
     if (model.patch_bias) {
         // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
@@ -904,11 +895,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
         ggml_set_name(window_mask, "window_mask");
         ggml_set_input(window_mask);
 
-        // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+        // embeddings shape: [n_embd, patches_w * patches_h, batch_size]
         GGML_ASSERT(batch_size == 1);
-        embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * 4, patches_w * patches_h * batch_size / 4);
         embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
-        embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, patches_w * patches_h, batch_size);
     }
 
     // loop over layers
@@ -961,7 +952,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
             KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+            cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size);
         }
 
         // attention output
@@ -978,11 +969,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
 
         // mlp
         // ffn_up
-        auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
-        cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
+        auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+        cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_up_b);
 
-        auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
-        cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
+        auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
+        cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_gate_b);
         // TODO : only 2 of these 3 are actually used, should we remove one of them?
         if (ctx->use_gelu) {
             cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
@@ -994,8 +985,8 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
         cur = ggml_mul(ctx0, cur_gate, cur_up);
 
         // ffn_down
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
 
         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);
@@ -1011,7 +1002,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
         embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
     }
 
-    embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
+    embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size);
 
     embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
     embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
@@ -1028,7 +1019,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
         ggml_set_name(window_idx, "window_idx");
         ggml_set_input(window_idx);
 
-        // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+        // embeddings shape: [n_embd, patches_w * patches_h, batch_size]
         GGML_ASSERT(batch_size == 1);
         embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
         embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
@@ -1074,9 +1065,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     const int patches_h            = image_size_height / patch_size;
     const int num_positions        = num_patches + (model.class_embedding ? 1 : 0);
     const int num_position_ids     = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions;
-    const int hidden_size          = hparams.hidden_size;
+    const int n_embd               = hparams.n_embd;
     const int n_head               = hparams.n_head;
-    const int d_head               = hidden_size / n_head;
+    const int d_head               = n_embd / n_head;
     const float eps                = hparams.eps;
     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
 
@@ -1114,17 +1105,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3));  // [w, h, c, b] -> [c, w, h, b]
         inp = ggml_reshape_4d(
             ctx0, inp,
-            hidden_size * 2, patches_w / 2, patches_h, batch_size);
+            n_embd * 2, patches_w / 2, patches_h, batch_size);
         inp = ggml_reshape_4d(
             ctx0, inp,
-            hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
+            n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
         inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
         inp = ggml_reshape_3d(
             ctx0, inp,
-            hidden_size, patches_w * patches_h, batch_size);
+            n_embd, patches_w * patches_h, batch_size);
     }
     else {
-        inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
+        inp = ggml_reshape_3d(ctx0, inp, num_patches, n_embd, batch_size);
         inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
     }
 
@@ -1137,7 +1128,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 
     // concat class_embeddings and patch_embeddings
     if (model.class_embedding) {
-        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, num_positions, batch_size);
         embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros
         embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
                 embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
@@ -1234,7 +1225,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+            cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size);
         }
 
         // attention output
@@ -1252,8 +1243,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
         }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);
 
         if (ctx->use_gelu) {
             cur = ggml_gelu_inplace(ctx0, cur);
@@ -1263,8 +1254,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             cur = ggml_gelu_quick_inplace(ctx0, cur);
         }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
 
         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);
@@ -1496,9 +1487,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         }
 
         { // attention
-            int hidden_size = clip_n_mmproj_embd(ctx);
+            int n_embd = clip_n_mmproj_embd(ctx);
             const int d_head = 128;
-            int n_head = hidden_size/d_head;
+            int n_head = n_embd/d_head;
             int num_query = 96;
             if (ctx->minicpmv_version == 2) {
                 num_query = 96;
@@ -1528,7 +1519,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
             KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-            KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
+            KQV = ggml_cont_3d(ctx0, KQV, n_embd, num_query, batch_size);
 
             embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
         }
@@ -1571,7 +1562,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     }
 
     else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
-        embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size);
 
         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
@@ -1696,9 +1687,9 @@ struct clip_model_loader {
             get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
             get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
 
-            get_u32(KEY_N_EMBD,         hparams.hidden_size);
+            get_u32(KEY_N_EMBD,         hparams.n_embd);
             get_u32(KEY_N_HEAD,         hparams.n_head);
-            get_u32(KEY_N_FF,           hparams.n_intermediate);
+            get_u32(KEY_N_FF,           hparams.n_ff);
             get_u32(KEY_N_BLOCK,        hparams.n_layer);
             get_u32(KEY_PROJ_DIM,       hparams.projection_dim);
             get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
@@ -1807,6 +1798,7 @@ struct clip_model_loader {
     }
 
     void load_tensors() {
+        auto & hparams = ctx_clip.vision_model.hparams;
         std::map<std::string, size_t> tensor_offset;
         std::vector<ggml_tensor *> tensors_to_load;
 
@@ -1860,8 +1852,8 @@ struct clip_model_loader {
         vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
 
         // layers
-        vision_model.layers.resize(vision_model.hparams.n_layer);
-        for (int il = 0; il < vision_model.hparams.n_layer; ++il) {
+        vision_model.layers.resize(hparams.n_layer);
+        for (int il = 0; il < hparams.n_layer; ++il) {
             auto & layer = vision_model.layers[il];
             layer.k_w    = get_tensor(string_format(TN_ATTN_K,      "v", il, "weight"));
             layer.q_w    = get_tensor(string_format(TN_ATTN_Q,      "v", il, "weight"));
@@ -1884,13 +1876,18 @@ struct clip_model_loader {
             layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
             layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"),   false);
 
-            // legacy naming (the in and out is reversed! don't ask me why)
-            layer.ff_i_w = layer.ff_down_w;
-            layer.ff_o_w = layer.ff_up_w;
-            layer.ff_g_w = layer.ff_gate_w;
-            layer.ff_i_b = layer.ff_down_b;
-            layer.ff_o_b = layer.ff_up_b;
-            layer.ff_g_b = layer.ff_gate_b;
+            // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
+            // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
+            if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
+                // swap up and down weights
+                ggml_tensor * tmp = layer.ff_up_w;
+                layer.ff_up_w = layer.ff_down_w;
+                layer.ff_down_w = tmp;
+                // swap up and down biases
+                tmp = layer.ff_up_b;
+                layer.ff_up_b = layer.ff_down_b;
+                layer.ff_down_b = tmp;
+            }
         }
 
         switch (ctx_clip.proj_type) {
@@ -2904,7 +2901,7 @@ int32_t clip_get_patch_size(const struct clip_ctx * ctx) {
 }
 
 int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
-    return ctx->vision_model.hparams.hidden_size;
+    return ctx->vision_model.hparams.n_embd;
 }
 
 const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
index dd18e0fe6ed0dfceac705e321e53ef3f61f5e935..4977d5480bd1d085fafd7d17e1e331a0626cb8cc 100644 (file)
@@ -92,6 +92,10 @@ struct mtmd_cli_context {
         batch = llama_batch_init(params.n_batch, 0, 1);
         n_batch = params.n_batch;
 
+        if (!model || !lctx) {
+            exit(1);
+        }
+
         if (!llama_model_chat_template(model, nullptr) && params.chat_template.empty()) {
             LOG_ERR("Model does not have chat template.\n");
             LOG_ERR("  For old llava models, you may need to use '--chat-template vicuna'\n");