]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
clip : fix pixtral on some GPU backends (#13097)
authorXuan-Son Nguyen <redacted>
Fri, 25 Apr 2025 12:31:42 +0000 (14:31 +0200)
committerGitHub <redacted>
Fri, 25 Apr 2025 12:31:42 +0000 (14:31 +0200)
* clip : fix pixtral on some GPU backends

* refactor inp_raw set

* rm outdated comment

* fix dynamic size

* add TODO

examples/llava/clip.cpp
tests/test-backend-ops.cpp

index 9a5ab7c819585c3b1109f4f0099115f35b0c05b9..da8a590f0e5638169078142b264ac340d793af78 100644 (file)
@@ -554,15 +554,15 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
 }
 
 // implementation of the 2D RoPE without adding a new op in ggml
+// this is not efficient (use double the memory), but works on all backends
+// TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
 static ggml_tensor * build_rope_2d(
-    ggml_cgraph * gf,
     ggml_context * ctx0,
     ggml_tensor * cur,
     ggml_tensor * pos_h,
     ggml_tensor * pos_w,
     const float freq_base
 ) {
-    ggml_tensor * tmp;
     const int64_t n_dim  = cur->ne[0];
     const int64_t n_head = cur->ne[1];
     const int64_t n_pos  = cur->ne[2];
@@ -571,18 +571,23 @@ static ggml_tensor * build_rope_2d(
     // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
     // first half of cur will use 1e-0, 1e-2 (even)
     // second half of cur will use 1e-1, 1e-3 (odd)
-    //
-    // for the first half, the trick here is to rotate n_dim/2, so inv_freq will be even
+    // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
     //  ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
     // then for the second half, we use freq_scale to shift the inv_freq
     //  ^ why? replace (2i) with (2i+1) in the above equation
     const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
 
     // first half
+    ggml_tensor * first;
     {
-        cur = ggml_rope_ext_inplace(
+        first = ggml_view_3d(ctx0, cur,
+            n_dim/2, n_head, n_pos,
+            ggml_row_size(cur->type, n_dim),
+            ggml_row_size(cur->type, n_dim*n_head),
+            0);
+        first = ggml_rope_ext(
             ctx0,
-            cur,
+            first,
             pos_h,      // positions
             nullptr,    // freq factors
             n_dim/2,    // n_dims
@@ -592,15 +597,17 @@ static ggml_tensor * build_rope_2d(
     }
 
     // second half
+    ggml_tensor * second;
     {
-        tmp = ggml_view_3d(ctx0, cur,
+        second = ggml_view_3d(ctx0, cur,
             n_dim/2, n_head, n_pos,
             ggml_row_size(cur->type, n_dim),
             ggml_row_size(cur->type, n_dim*n_head),
             n_dim/2 * ggml_element_size(cur));
-        tmp = ggml_rope_ext_inplace(
+        second = ggml_cont(ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors
+        second = ggml_rope_ext(
             ctx0,
-            tmp,
+            second,
             pos_w,      // positions
             nullptr,    // freq factors
             n_dim/2,    // n_dims
@@ -608,10 +615,9 @@ static ggml_tensor * build_rope_2d(
             freq_scale_odd,
             0.0f, 1.0f, 0.0f, 0.0f
         );
-        // calculate inplace (modify cur directly)
-        ggml_build_forward_expand(gf, tmp);
     }
 
+    cur = ggml_concat(ctx0, first, second, 0);
     return cur;
 }
 
@@ -680,13 +686,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
             struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
 
             Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
-            Q = build_rope_2d(gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta);
+            Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta);
             Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
 
             struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
 
             K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
-            K = build_rope_2d(gf, ctx0, K, pos_h, pos_w, hparams.rope_theta);
+            K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta);
             K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
 
             struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
@@ -2796,10 +2802,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
 
+    // TODO @ngxson : this is ugly, need to refactor later
+    bool support_dynamic_size = ctx->has_minicpmv_projector
+        || ctx->has_qwen2vl_merger
+        || ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
+
     const int image_size = hparams.image_size;
     int image_size_width  = image_size;
     int image_size_height = image_size;
-    if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) {
+    if (support_dynamic_size) {
         image_size_width  = imgs.entries[0]->nx;
         image_size_height = imgs.entries[0]->ny;
     }
@@ -2811,9 +2822,20 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 
     {
         struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
-        float * data = (float *)malloc(ggml_nbytes(inp_raw));
+        std::vector<float> inp_data(ggml_nelements(inp_raw));
+        float * data = inp_data.data();
+
+        // layout of data (note: the channel dim is unrolled to better visualize the layout):
+        //
+        // ┌──W──┐
+        // │     H │  channel = R
+        // ├─────┤ │
+        // │     H │  channel = G
+        // ├─────┤ │
+        // │     H │  channel = B
+        // └─────┘ │
+        //   ──────┘ x B
 
-        // TODO @ngxson : this whole code block is ugly, will need to be refactored
         for (size_t i = 0; i < imgs.entries.size(); i++) {
             const int nx = imgs.entries[i]->nx;
             const int ny = imgs.entries[i]->ny;
@@ -2828,17 +2850,19 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             const int n = nx * ny;
 
             for (int b = 0; b < batch_size; b++) {
-                for (int k = 0; k < 3; k++) {
-                    for (int y = 0; y < ny; y++) {
-                        for (int x = 0; x < nx; x++) {
-                            data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k];
-                        }
+                float * batch_entry = data + b * (3*n);
+                for (int y = 0; y < ny; y++) {
+                    for (int x = 0; x < nx; x++) {
+                        size_t base_src = 3*(y * nx + x); // idx of the first channel
+                        size_t base_dst =    y * nx + x;  // idx of the first channel
+                        batch_entry[      base_dst] = imgs.entries[b]->buf[base_src    ];
+                        batch_entry[1*n + base_dst] = imgs.entries[b]->buf[base_src + 1];
+                        batch_entry[2*n + base_dst] = imgs.entries[b]->buf[base_src + 2];
                     }
                 }
             }
         }
         ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
-        free(data);
     }
     if (ctx->has_minicpmv_projector) {
         {
index 61751755b317b435c0e04d3e0e765ad590ef482d..d70acb77194352ac366634917d94471ffd3999a9 100644 (file)
@@ -2606,6 +2606,8 @@ struct test_rope : public test_case {
             } else {
                 out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
             }
+
+            // TODO: add test with a non-contiguous view as input ; this case is needed for build_rope_2d in clip.cpp
         }
         ggml_set_name(out, "out");