]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
clip : use smart pointer (⚠️ breaking change) (#12869)
authorXuan-Son Nguyen <redacted>
Fri, 11 Apr 2025 10:09:39 +0000 (12:09 +0200)
committerGitHub <redacted>
Fri, 11 Apr 2025 10:09:39 +0000 (12:09 +0200)
* clip : use smart pointers

* fix warmup

* add forward declaration

* misisng include

* fix include (2)

* composite

* simplify batch ptr

* fix conflict

examples/llava/clip-impl.h
examples/llava/clip.cpp
examples/llava/clip.h
examples/llava/llava.cpp
examples/llava/mtmd.cpp

index 4c035298749245e8a04aa87b31c06dfdb22a7eb6..4d7340a56bd0caf2372b1efd5af86f6c007928c2 100644 (file)
@@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "gguf.h"
+#include "clip.h"
 
 #include "clip.h"
 
@@ -202,23 +203,31 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
 // cpp wrappers
 //
 
+// wrapper for clip_image_size
+struct clip_image_size_deleter {
+    void operator()(clip_image_size * val) { clip_image_size_free(val); }
+};
+typedef std::unique_ptr<clip_image_size, clip_image_size_deleter> clip_image_size_ptr;
+
+// wrapper for clip_image_u8
 struct clip_image_u8_deleter {
     void operator()(clip_image_u8 * val) { clip_image_u8_free(val); }
 };
+typedef std::unique_ptr<clip_image_u8, clip_image_u8_deleter> clip_image_u8_ptr;
 
+// wrapper for clip_image_f32
 struct clip_image_f32_deleter {
     void operator()(clip_image_f32 * val) { clip_image_f32_free(val); }
 };
+typedef std::unique_ptr<clip_image_f32, clip_image_f32_deleter> clip_image_f32_ptr;
 
-struct clip_image_f32_batch_deleter {
-    void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
+struct clip_image_u8_batch {
+    std::vector<clip_image_u8_ptr> entries;
 };
 
-typedef std::unique_ptr<clip_image_u8, clip_image_u8_deleter> clip_image_u8_ptr;
-typedef std::unique_ptr<clip_image_f32, clip_image_f32_deleter> clip_image_f32_ptr;
-typedef std::unique_ptr<clip_image_f32_batch, clip_image_f32_batch_deleter> clip_image_f32_batch_ptr;
-
-// TODO @ngxson : we're currently having a naming clash between struct clip_image_size and function clip_image_size()
+struct clip_image_f32_batch {
+    std::vector<clip_image_f32_ptr> entries;
+};
 
 //
 // common utils
index 710309edaecd6a31cb9608444a13939b49ea6aae..a55b3f3835184cd148a2524e76efb6083e74d7da 100644 (file)
@@ -315,58 +315,47 @@ struct clip_ctx {
     bool use_gelu = false;
     bool use_silu = false;
 
-    struct gguf_context * ctx_gguf = nullptr;
-    struct ggml_context * ctx_data = nullptr;
+    gguf_context_ptr ctx_gguf;
+    ggml_context_ptr ctx_data;
 
     std::vector<uint8_t> buf_compute_meta;
 
     std::vector<ggml_backend_t> backend_ptrs;
     std::vector<ggml_backend_buffer_type_t> backend_buft;
 
-    ggml_backend_t backend     = nullptr;
-    ggml_backend_t backend_cpu = nullptr;
-    ggml_backend_buffer_t buf  = nullptr;
+    ggml_backend_ptr backend;
+    ggml_backend_ptr backend_cpu;
+    ggml_backend_buffer_ptr buf;
 
     ggml_backend_sched_ptr sched;
 
-    struct clip_image_size * load_image_size = nullptr;
+    clip_image_size load_image_size;
 
     clip_ctx(clip_context_params & ctx_params) {
-        backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
-        backend     = ctx_params.use_gpu
+        backend_cpu = ggml_backend_ptr(ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr));
+        backend     = ggml_backend_ptr(ctx_params.use_gpu
                         ? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr)
-                        : nullptr;
+                        : nullptr);
 
         if (backend) {
-            LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend));
-            backend_ptrs.push_back(backend);
-            backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
+            LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend.get()));
+            backend_ptrs.push_back(backend.get());
+            backend_buft.push_back(ggml_backend_get_default_buffer_type(backend.get()));
         } else {
-            backend = backend_cpu;
+            backend = std::move(backend_cpu);
             LOG_INF("%s: CLIP using CPU backend\n", __func__);
         }
 
-        backend_ptrs.push_back(backend_cpu);
-        backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
+        backend_ptrs.push_back(backend_cpu.get());
+        backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu.get()));
 
         sched.reset(
             ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false)
         );
     }
-
-    ~clip_ctx() {
-        ggml_free(ctx_data);
-        gguf_free(ctx_gguf);
-        ggml_backend_buffer_free(buf);
-        ggml_backend_free(backend);
-        if (backend_cpu != backend) {
-            ggml_backend_free(backend_cpu);
-        }
-        clip_image_size_free(load_image_size);
-    }
 };
 
-static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
+static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
 
@@ -382,7 +371,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
     const int n_layer              = hparams.n_layer;
     const float eps                = hparams.eps;
 
-    GGML_ASSERT(imgs->size == 1); // batch_size == 1
+    GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
 
     struct ggml_init_params params = {
         /*.mem_size   =*/ ctx->buf_compute_meta.size(),
@@ -390,7 +379,9 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
         /*.no_alloc   =*/ true,
     };
 
-    struct ggml_context * ctx0 = ggml_init(params);
+    ggml_context_ptr ctx0_ptr(ggml_init(params));
+    auto ctx0 = ctx0_ptr.get();
+
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     // input raw
@@ -512,12 +503,10 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
 
-    ggml_free(ctx0);
-
     return gf;
 }
 
-static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
     if (!ctx->has_vision_encoder) {
         LOG_ERR("This gguf file seems to have no vision encoder\n");
         return nullptr;
@@ -530,23 +519,20 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     int image_size_width  = image_size;
     int image_size_height = image_size;
     if (ctx->has_minicpmv_projector) {
-        if (load_image_size == nullptr) {
-            load_image_size = clip_image_size_init();
-        }
-        LOG_DBG("%s: %d %d\n", __func__, load_image_size->width, load_image_size->height);
-        image_size_width  = load_image_size->width;
-        image_size_height = load_image_size->height;
+        LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height);
+        image_size_width  = load_image_size.width;
+        image_size_height = load_image_size.height;
         if (is_inf) {
-            image_size_width  = imgs->data->nx;
-            image_size_height = imgs->data->ny;
+            image_size_width  = imgs.entries[0]->nx;
+            image_size_height = imgs.entries[0]->ny;
         }
     }
     else if (ctx->has_qwen2vl_merger) {
         // use the image's native resolution when image is avaible
         if (is_inf) {
         // if (imgs->data->nx && imgs->data->ny) {
-            image_size_width  = imgs->data->nx;
-            image_size_height = imgs->data->ny;
+            image_size_width  = imgs.entries[0]->nx;
+            image_size_height = imgs.entries[0]->ny;
         }
     }
     const int patch_size           = hparams.patch_size;
@@ -561,7 +547,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     const float eps                = hparams.eps;
     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
 
-    const int batch_size = imgs->size;
+    const int batch_size = imgs.entries.size();
 
     if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) {
         GGML_ASSERT(batch_size == 1);
@@ -573,7 +559,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         /*.no_alloc   =*/ true,
     };
 
-    struct ggml_context * ctx0 = ggml_init(params);
+    ggml_context_ptr ctx0_ptr(ggml_init(params));
+    auto ctx0 = ctx0_ptr.get();
+
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
@@ -1061,7 +1049,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
                 embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
             }
         } else {
-            GGML_ABORT("fatel error");
+            GGML_ABORT("fatal error");
         }
     }
     else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
@@ -1081,12 +1069,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
 
-    ggml_free(ctx0);
-
     return gf;
 }
 
-static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
     if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         return clip_image_build_graph_siglip(ctx, imgs);
     } else {
@@ -1257,7 +1243,7 @@ struct clip_model_loader {
             /*.mem_buffer =*/ NULL,
             /*.no_alloc =*/ true,
         };
-        ctx_clip.ctx_data = ggml_init(params);
+        ctx_clip.ctx_data.reset(ggml_init(params));
         if (!ctx_clip.ctx_data) {
             throw std::runtime_error(string_format("%s: failed to init ggml context\n", __func__));
         }
@@ -1271,7 +1257,7 @@ struct clip_model_loader {
             if (cur) {
                 tensors_to_load.push_back(cur);
                 // add tensors to context
-                struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data, cur);
+                struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
                 ggml_set_name(data_tensor, cur->name);
                 cur = data_tensor;
             }
@@ -1442,11 +1428,11 @@ struct clip_model_loader {
             }
 
             // alloc memory and offload data
-            ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
-            ctx_clip.buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data, buft);
-            ggml_backend_buffer_set_usage(ctx_clip.buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+            ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend.get());
+            ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
+            ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
             for (auto & t : tensors_to_load) {
-                struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data, t->name);
+                struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
                 const size_t offset = tensor_offset[t->name];
                 fin.seekg(offset, std::ios::beg);
                 if (!fin) {
@@ -1471,10 +1457,20 @@ struct clip_model_loader {
 
     void alloc_compute_meta() {
         ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
+
+        // create a fake batch
         clip_image_f32_batch batch;
-        batch.size = 1;
-        batch.data = nullptr;
-        ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, &batch, nullptr, false);
+        clip_image_f32_ptr img(clip_image_f32_init());
+        clip_image_size image_size;
+        image_size.width  = clip_get_image_size(&ctx_clip);
+        image_size.height = clip_get_image_size(&ctx_clip);
+        int n_patches = clip_get_image_size(&ctx_clip) / image_size.width;
+        img->nx = n_patches;
+        img->ny = n_patches;
+        img->buf.resize(n_patches * image_size.width * image_size.height * 3);
+        batch.entries.push_back(std::move(img));
+
+        ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
         ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
         for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) {
             ggml_backend_t backend = ctx_clip.backend_ptrs[i];
@@ -1575,11 +1571,11 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
 }
 
 void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) {
-    ctx_clip->load_image_size = load_image_size;
+    ctx_clip->load_image_size = *load_image_size; // copy
 }
 
 struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
-    return ctx_clip->load_image_size;
+    return &ctx_clip->load_image_size;
 }
 
 struct clip_image_size * clip_image_size_init() {
@@ -1597,6 +1593,10 @@ struct clip_image_f32 * clip_image_f32_init() {
     return new clip_image_f32();
 }
 
+struct clip_image_f32_batch * clip_image_f32_batch_init() {
+    return new clip_image_f32_batch();
+}
+
 unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) {
     if (nx) *nx = img->nx;
     if (ny) *ny = img->ny;
@@ -1609,19 +1609,37 @@ void clip_image_size_free(struct clip_image_size * load_image_size) {
     }
     delete load_image_size;
 }
-void clip_image_u8_free(struct clip_image_u8  * img) { delete img; }
-void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
-void clip_image_u8_batch_free(struct clip_image_u8_batch  * batch) {
-    if (batch->size > 0) {
-        delete[] batch->data;
-        batch->size = 0;
+void clip_image_u8_free(struct clip_image_u8  * img) { if (img) delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
+
+size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
+    return batch->entries.size();
+}
+
+size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return 0;
+    }
+    return batch->entries[idx]->nx;
+}
+
+size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return 0;
     }
+    return batch->entries[idx]->ny;
 }
-void clip_image_f32_batch_free(struct clip_image_f32_batch  * batch) {
-    if (batch->size > 0) {
-        delete[] batch->data;
-        batch->size = 0;
+
+clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return nullptr;
     }
+    return batch->entries[idx].get();
 }
 
 void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
@@ -1695,14 +1713,15 @@ static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int ta
 }
 
 // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
-static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) {
-    dst->nx = src->nx;
-    dst->ny = src->ny;
-    dst->buf.resize(src->buf.size());
+static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) {
+    dst.nx = src.nx;
+    dst.ny = src.ny;
+    dst.buf.resize(src.buf.size());
 
-    for (size_t i = 0; i < src->buf.size(); ++i) {
+    // TODO @ngxson : seems like this could be done more efficiently on cgraph
+    for (size_t i = 0; i < src.buf.size(); ++i) {
         int c = i % 3; // rgb
-        dst->buf[i] = (static_cast<float>(src->buf[i]) / 255.0f - mean[c]) / std[c];
+        dst.buf[i] = (static_cast<float>(src.buf[i]) / 255.0f - mean[c]) / std[c];
     }
 }
 
@@ -1710,7 +1729,7 @@ inline int clip(int x, int lower, int upper) {
     return std::max(lower, std::min(x, upper));
 }
 
-static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) {
+static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
     const int nx = img.nx;
     const int ny = img.ny;
 
@@ -1848,13 +1867,13 @@ static std::pair<int, int> select_best_resolution(const std::pair<int, int> & or
     return best_fit;
 }
 
-static std::vector<clip_image_u8*> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
-    std::vector<clip_image_u8*> patches;
+static std::vector<clip_image_u8_ptr> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
+    std::vector<clip_image_u8_ptr> patches;
     int width = image.nx;
     int height = image.ny;
     for (int i = 0; i < height; i += patch_size) {
         for (int j = 0; j < width; j += patch_size) {
-            clip_image_u8 *patch = clip_image_u8_init();
+            clip_image_u8_ptr patch(clip_image_u8_init());
             patch->nx = std::min(patch_size, width - j);
             patch->ny = std::min(patch_size, height - i);
             patch->buf.resize(3 * patch->nx * patch->ny);
@@ -1865,7 +1884,7 @@ static std::vector<clip_image_u8*> divide_to_patches_u8(const clip_image_u8 & im
                     }
                 }
             }
-            patches.push_back(patch);
+            patches.push_back(std::move(patch));
         }
     }
     return patches;
@@ -1946,7 +1965,7 @@ static std::pair<int, int> uhd_best_grid(const int max_slice_nums, const int mul
 //    -> https://arxiv.org/pdf/2403.11703
 //    -> https://github.com/thunlp/LLaVA-UHD
 //    -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
-static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) {
+static std::vector<std::vector<clip_image_u8_ptr>> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) {
     const std::pair<int, int> original_size={img->nx,img->ny};
     const int original_width = img->nx;
     const int original_height = img->ny;
@@ -1954,30 +1973,30 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
     const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
     const int multiple = fmin(ceil(ratio), max_slice_nums);
 
-    std::vector<std::vector<clip_image_u8 *>> images;
+    std::vector<std::vector<clip_image_u8_ptr>> images;
     LOG_DBG("%s: multiple %d\n", __func__, multiple);
-    images.push_back(std::vector<clip_image_u8 *>());
+    images.push_back(std::vector<clip_image_u8_ptr>());
 
     if (multiple <= 1) {
         auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true);
-        clip_image_u8 * source_image = clip_image_u8_init();
+        clip_image_u8_ptr source_image(clip_image_u8_init());
         bicubic_resize(*img, *source_image, best_size.first, best_size.second);
         // source_image = image.resize(best_size, Image.Resampling.BICUBIC)
-        images[images.size()-1].push_back(source_image);
+        images.back().push_back(std::move(source_image));
     }
     else if (multiple > 1) {
         auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size);
-        clip_image_u8 * source_image = clip_image_u8_init();
+        clip_image_u8_ptr source_image(clip_image_u8_init());
         bicubic_resize(*img, *source_image, best_size.first, best_size.second);
         // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
         LOG_DBG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second);
-        images[images.size()-1].push_back(source_image);
+        images.back().push_back(std::move(source_image));
 
         std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
         LOG_DBG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second);
 
         auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
-        clip_image_u8 * refine_image = clip_image_u8_init();
+        clip_image_u8_ptr refine_image(clip_image_u8_init());
         bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second);
 
         LOG_DBG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second);
@@ -1988,9 +2007,9 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
         int grid_x = int(width / best_grid.first);
         int grid_y = int(height / best_grid.second);
         for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){
-            images.push_back(std::vector<clip_image_u8 *>());
+            images.push_back(std::vector<clip_image_u8_ptr>());
             for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){
-                clip_image_u8 * patch = clip_image_u8_init();
+                clip_image_u8_ptr patch(clip_image_u8_init());
                 patch->nx = grid_x;
                 patch->ny = grid_y;
                 patch->buf.resize(3 * patch->nx * patch->ny);
@@ -2003,10 +2022,9 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
                         patch->buf[j+2] = refine_image->buf[i+2];
                     }
                 }
-                images[images.size()-1].push_back(patch);
+                images.back().push_back(std::move(patch));
             }
         }
-        clip_image_u8_free(refine_image);
     }
     return images;
 }
@@ -2014,8 +2032,8 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
 int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
     const int max_slice_nums=9;
     const int scale_resolution=448;
-    const int original_width = ctx_clip->load_image_size->width;
-    const int original_height = ctx_clip->load_image_size->height;
+    const int original_width = ctx_clip->load_image_size.width;
+    const int original_height = ctx_clip->load_image_size.height;
     const float log_ratio = log(1.0*original_width/original_height);
     const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
     const int multiple = fmin(ceil(ratio), max_slice_nums);
@@ -2025,64 +2043,44 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
 
 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
-bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) {
+bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
 
-    if(clip_is_minicpmv(ctx)){
+    if (clip_is_minicpmv(ctx)) {
         int max_slice_nums = 9;
-        std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img, max_slice_nums);
-        res_imgs->size = 0;
-        for (size_t i = 0; i < imgs.size(); ++i){
-            res_imgs->size += imgs[i].size();
-        }
-        res_imgs->data = new clip_image_f32[res_imgs->size];
-        int idx = 0;
+        std::vector<std::vector<clip_image_u8_ptr>> imgs = uhd_slice_image(img, max_slice_nums);
         for (size_t i = 0; i < imgs.size(); ++i) {
             for (size_t j = 0; j < imgs[i].size(); ++j) {
                 LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny);
-                clip_image_f32 * res = clip_image_f32_init();
-                normalize_image_u8_to_f32(imgs[i][j], res, ctx->image_mean, ctx->image_std);
-                res_imgs->data[idx++] = *res;
-                clip_image_f32_free(res);
-            }
-        }
-        for (size_t i = 0; i < imgs.size(); ++i) {
-            for (size_t j = 0; j < imgs[i].size(); ++j) {
-                if (imgs[i][j] != nullptr) {
-                    clip_image_u8_free(imgs[i][j]);
-                }
+                clip_image_f32_ptr res(clip_image_f32_init());
+                normalize_image_u8_to_f32(*imgs[i][j], *res, ctx->image_mean, ctx->image_std);
+                res_imgs->entries.push_back(std::move(res));
             }
         }
         return true;
     }
     else if (ctx->has_qwen2vl_merger) {
-        clip_image_u8 * resized = clip_image_u8_init();
-        auto patch_size = clip_patch_size(ctx) * 2;
+        clip_image_u8 resized;
+        auto patch_size = clip_get_patch_size(ctx) * 2;
         int nx = ceil((float)img->nx / patch_size) * patch_size;
         int ny = ceil((float)img->ny / patch_size) * patch_size;
-        bicubic_resize(*img, *resized, nx, ny);
+        bicubic_resize(*img, resized, nx, ny);
 
-        res_imgs->data = new clip_image_f32[1];
-        // clip_image_f32 * res = clip_image_f32_init();
-        normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std);
+        clip_image_f32_ptr img_f32(clip_image_f32_init());
+        // clip_image_f32_ptr res(clip_image_f32_init());
+        normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std);
         // res_imgs->data[0] = *res;
-        res_imgs->size = 1;
-
-        // clip_image_f32_free(res);
-        clip_image_u8_free(resized);
+        res_imgs->entries.push_back(std::move(img_f32));
         return true;
     }
 
     if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
-        res_imgs->size = 1;
-        res_imgs->data = new clip_image_f32[res_imgs->size];
         clip_image_u8 resized_image;
         int32_t sz=ctx->vision_model.hparams.image_size;
         bicubic_resize(*img, resized_image,sz,sz);
-        clip_image_f32 * res = clip_image_f32_init();
+        clip_image_f32_ptr img_f32(clip_image_f32_init());
         //clip_image_save_to_bmp(resized_image, "resized.bmp");
-        normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std);
-        res_imgs->data[0] = *res;
-        clip_image_f32_free(res);
+        normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
+        res_imgs->entries.push_back(std::move(img_f32));
         return true;
     }
 
@@ -2097,16 +2095,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
         pad_to_square = false;
     }
     // free the previous res_imgs if any set
-    if (res_imgs->size > 0) {
-        clip_image_f32_batch_free(res_imgs);
-    }
-    res_imgs->data = nullptr;
-    res_imgs->size = 0;
+    res_imgs->entries.clear();
 
     // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
     // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
 
-    clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily
+    clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily
     if (pad_to_square && img->nx != img->ny) {
         int longer_side = std::max(img->nx, img->ny);
         temp->nx = longer_side;
@@ -2149,28 +2143,18 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
             //     clip_image_u8_free(temp2);
             // }
 
-            std::vector<clip_image_u8 *> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
+            std::vector<clip_image_u8_ptr> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
 
-            clip_image_u8 *image_original_resize = clip_image_u8_init();
+            clip_image_u8_ptr image_original_resize(clip_image_u8_init());
             // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
             bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
-            patches.insert(patches.begin(), image_original_resize);
-            // clip_image_f32_batch_init(patches.size());
-            res_imgs->size = patches.size();
-            res_imgs->data = new clip_image_f32[res_imgs->size];
-            int num=0;
-            for (auto& patch : patches) {
-                normalize_image_u8_to_f32(patch, &res_imgs->data[num], ctx->image_mean, ctx->image_std);
-                num++;
-            }
-
-            for (size_t i = 0; i < patches.size(); i++) {
-                // LOG_DBG("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny);
-                clip_image_u8_free(patches[i]);
+            patches.insert(patches.begin(), std::move(image_original_resize));
+            for (auto & patch : patches) {
+                clip_image_f32_ptr res(clip_image_f32_init());
+                normalize_image_u8_to_f32(*patch, *res, ctx->image_mean, ctx->image_std);
+                res_imgs->entries.push_back(std::move(res));
             }
 
-            clip_image_u8_free(temp);
-
             return true;
         } else {
             temp->nx = img->nx;
@@ -2186,7 +2170,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
 
     const int nx2 = ctx->vision_model.hparams.image_size;
     const int ny2 = ctx->vision_model.hparams.image_size;
-    clip_image_f32 * res = clip_image_f32_init();
+    clip_image_f32_ptr res(clip_image_f32_init());
     res->nx = nx2;
     res->ny = ny2;
     res->buf.resize(3 * nx2 * ny2);
@@ -2238,7 +2222,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
             }
         }
     }
-    clip_image_u8_free(temp);
 
     // {
     //     clip_image_u8 * temp2 = clip_image_u8_init();
@@ -2248,10 +2231,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
     // }
     // res_imgs.push_back(res);
 
-    res_imgs->size = 1;
-    res_imgs->data = new clip_image_f32[res_imgs->size];
-    res_imgs->data[0] = *res;
-    clip_image_f32_free(res);
+    res_imgs->entries.push_back(std::move(res));
 
     return true;
 }
@@ -2279,15 +2259,15 @@ size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w
     return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
 }
 
-int32_t clip_image_size(const struct clip_ctx * ctx) {
+int32_t clip_get_image_size(const struct clip_ctx * ctx) {
     return ctx->vision_model.hparams.image_size;
 }
 
-int32_t clip_patch_size(const struct clip_ctx * ctx) {
+int32_t clip_get_patch_size(const struct clip_ctx * ctx) {
     return ctx->vision_model.hparams.patch_size;
 }
 
-int32_t clip_hidden_size(const struct clip_ctx * ctx) {
+int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
     return ctx->vision_model.hparams.hidden_size;
 }
 
@@ -2434,19 +2414,23 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
         return false;
     }
 
-    clip_image_f32_batch imgs{};
-    imgs.size = 1;
-    imgs.data = img;
+    clip_image_f32_batch imgs;
+    clip_image_f32_ptr img_copy(clip_image_f32_init());
+    *img_copy = *img;
+    imgs.entries.push_back(std::move(img_copy));
+
     return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
 }
 
-bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) {
+bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
+    const clip_image_f32_batch & imgs = *imgs_c_ptr;
+
     if (!ctx->has_vision_encoder) {
         LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
         return false;
     }
 
-    int batch_size = imgs->size;
+    int batch_size = imgs.entries.size();
     if (ctx->has_llava_projector) {
         GGML_ASSERT(batch_size == 1); // TODO: support multiple images
     }
@@ -2473,25 +2457,22 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     int image_size_width  = image_size;
     int image_size_height = image_size;
     if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) {
-        image_size_width  = imgs->data[0].nx;
-        image_size_height = imgs->data[0].ny;
+        image_size_width  = imgs.entries[0]->nx;
+        image_size_height = imgs.entries[0]->ny;
     }
     const int patch_size    = hparams.patch_size;
     const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size));
     const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
-    if(ctx->load_image_size==nullptr){
-        ctx->load_image_size= clip_image_size_init();
-    }
-    const int pos_w = ctx->load_image_size->width/patch_size;
-    const int pos_h = ctx->load_image_size->height/patch_size;
+    const int pos_w = ctx->load_image_size.width / patch_size;
+    const int pos_h = ctx->load_image_size.height / patch_size;
 
     {
         struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
         float * data = (float *)malloc(ggml_nbytes(inp_raw));
 
-        for (size_t i = 0; i < imgs->size; i++) {
-            const int nx = imgs->data[i].nx;
-            const int ny = imgs->data[i].ny;
+        for (size_t i = 0; i < imgs.entries.size(); i++) {
+            const int nx = imgs.entries[i]->nx;
+            const int ny = imgs.entries[i]->ny;
             if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
                 GGML_ASSERT(nx == image_size && ny == image_size);
             }
@@ -2502,7 +2483,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 for (int k = 0; k < 3; k++) {
                     for (int y = 0; y < ny; y++) {
                         for (int x = 0; x < nx; x++) {
-                            data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
+                            data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k];
                         }
                     }
                 }
@@ -2629,7 +2610,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         }
     }
 
-    ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
+    ggml_backend_cpu_set_n_threads(ctx->backend_cpu.get(), n_threads);
 
     auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
     if (status != GGML_STATUS_SUCCESS) {
@@ -2662,8 +2643,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
         /* verbosity */ GGML_LOG_LEVEL_ERROR,
     });
 
-    const auto & ctx_src = ctx_clip->ctx_gguf;
-    const auto & ctx_data = ctx_clip->ctx_data;
+    const auto & ctx_src = ctx_clip->ctx_gguf.get();
+    const auto & ctx_data = ctx_clip->ctx_data.get();
 
     auto * ctx_out = gguf_init_empty();
     gguf_set_kv(ctx_out, ctx_src);
index f61e0c0b2b3a73d26b3e5d66249c95567000bdc2..cc133a58de3e8f4b05c68098633009855b83eb7b 100644 (file)
@@ -30,15 +30,8 @@ struct clip_image_size {
     int height;
 };
 
-struct clip_image_u8_batch {
-    struct clip_image_u8 * data;
-    size_t size;
-};
-
-struct clip_image_f32_batch {
-    struct clip_image_f32 * data;
-    size_t size;
-};
+struct clip_image_u8_batch;
+struct clip_image_f32_batch;
 
 struct clip_context_params {
     bool use_gpu;
@@ -55,9 +48,9 @@ CLIP_API void clip_free(struct clip_ctx * ctx);
 CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
 CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
 
-CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
-CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
-CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
+CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx);
+CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx);
+CLIP_API int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
 
 // TODO: should be enum, not string
 CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
@@ -73,9 +66,10 @@ CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
 CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
 CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
 
-CLIP_API struct clip_image_size * clip_image_size_init();
-CLIP_API struct clip_image_u8  * clip_image_u8_init ();
-CLIP_API struct clip_image_f32 * clip_image_f32_init();
+CLIP_API struct clip_image_size      * clip_image_size_init();
+CLIP_API struct clip_image_u8        * clip_image_u8_init ();
+CLIP_API struct clip_image_f32       * clip_image_f32_init();
+CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava
 
 // nx, ny are the output image dimensions
 CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
@@ -86,6 +80,12 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
 CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch  * batch);
 CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
 
+// use for accessing underlay data of clip_image_f32_batch
+CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
+CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
+CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
+CLIP_API clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
+
 /**
  * Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
  * The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
index 518aad3f1f70b68c43ed750bcb1f26634e7a50b2..03a22cbb4c20541cbf5e0bc562759cc7adfead4c 100644 (file)
@@ -10,6 +10,7 @@
 #include <cstring>
 #include <limits>
 #include <vector>
+#include <memory>
 
 #if defined(LLAVA_LOG_OFF)
 #   define LOG_INF(...)
@@ -45,6 +46,17 @@ struct clip_image_grid_shape {
     int second;
 };
 
+// convenience cpp wrapper
+struct clip_image_f32_batch_deleter {
+    void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
+};
+typedef std::unique_ptr<clip_image_f32_batch, clip_image_f32_batch_deleter> clip_image_f32_batch_ptr;
+
+struct clip_image_size_deleter {
+    void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
+};
+typedef std::unique_ptr<clip_image_size, clip_image_size_deleter> clip_image_size_ptr;
+
 /**
  * Selects the best resolution from a list of possible resolutions based on the original size.
  *
@@ -105,8 +117,8 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
         struct ggml_context * ctx;
     } model;
 
-    const int32_t image_size = clip_image_size(ctx_clip);
-    const int32_t patch_size = clip_patch_size(ctx_clip);
+    const int32_t image_size = clip_get_image_size(ctx_clip);
+    const int32_t patch_size = clip_get_patch_size(ctx_clip);
 
     int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches)
 
@@ -246,12 +258,9 @@ static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size)
 
 static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
     // std::vector<clip_image_f32*> img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336
-    clip_image_f32_batch img_res_v;
-    img_res_v.size = 0;
-    img_res_v.data = nullptr;
-    if (!clip_image_preprocess(ctx_clip, img, &img_res_v)) {
+    clip_image_f32_batch_ptr img_res_v(clip_image_f32_batch_init());
+    if (!clip_image_preprocess(ctx_clip, img, img_res_v.get())) {
         LOG_ERR("%s: unable to preprocess image\n", __func__);
-        delete[] img_res_v.data;
         return false;
     }
 
@@ -259,66 +268,72 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
 
     const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip);
 
+    const size_t n_imgs = clip_image_f32_batch_n_images(img_res_v.get());
+
     if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) {
         std::vector<float *> image_embd_v;
-        image_embd_v.resize(img_res_v.size);
-        struct clip_image_size * load_image_size = clip_image_size_init();
+        image_embd_v.resize(n_imgs);
+        clip_image_size load_image_size;
 
-        for (size_t i = 0; i < img_res_v.size; i++) {
+        for (size_t i = 0; i < n_imgs; i++) {
             const int64_t t_img_enc_step_start_us = ggml_time_us();
-            image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
-            int patch_size=14;
-            load_image_size->width = img_res_v.data[i].nx;
-            load_image_size->height = img_res_v.data[i].ny;
-            clip_add_load_image_size(ctx_clip, load_image_size);
+            int nx = clip_image_f32_batch_nx(img_res_v.get(), i);
+            int ny = clip_image_f32_batch_ny(img_res_v.get(), i);
+            image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, nx, ny));
+            int patch_size = 14;
+            load_image_size.width = nx;
+            load_image_size.height = ny;
+            clip_add_load_image_size(ctx_clip, &load_image_size);
 
             bool encoded = false;
+            clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
             if (clip_is_qwen2vl(ctx_clip)) {
-                encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
+                encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]);
             }
             else {
-                encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
+                encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(img_res, patch_size), image_embd_v[i]);
             }
 
             if (!encoded) {
-                LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
+                LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs);
                 return false;
             }
             const int64_t t_img_enc_steop_batch_us = ggml_time_us();
-            LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)img_res_v.size, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0);
+            LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)n_imgs, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0);
         }
         const int64_t t_img_enc_batch_us = ggml_time_us();
-        LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
+        LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
 
         int n_img_pos_out = 0;
         for (size_t i = 0; i < image_embd_v.size(); i++) {
+            int nx = clip_image_f32_batch_nx(img_res_v.get(), i);
+            int ny = clip_image_f32_batch_ny(img_res_v.get(), i);
+            clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
             std::memcpy(
                 image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
                 image_embd_v[i],
-                clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
-            n_img_pos_out += clip_n_patches_by_img(ctx_clip, &img_res_v.data[i]);
+                clip_embd_nbytes_by_img(ctx_clip, nx, ny));
+            n_img_pos_out += clip_n_patches_by_img(ctx_clip, img_res);
         }
         *n_img_pos = n_img_pos_out;
         for (size_t i = 0; i < image_embd_v.size(); i++) {
             free(image_embd_v[i]);
         }
         image_embd_v.clear();
-        load_image_size->width = img->nx;
-        load_image_size->height = img->ny;
-        clip_add_load_image_size(ctx_clip, load_image_size);
-        LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height);
-        delete[] img_res_v.data;
-        img_res_v.size = 0;
-        img_res_v.data = nullptr;
+        load_image_size.width = img->nx;
+        load_image_size.height = img->ny;
+        clip_add_load_image_size(ctx_clip, &load_image_size);
+        LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size.width, load_image_size.height);
     }
     else if (clip_is_glm(ctx_clip)){
         struct clip_image_size * load_image_size = clip_image_size_init();
-        load_image_size->width = img_res_v.data[0].nx;
-        load_image_size->height = img_res_v.data[0].ny;
+        load_image_size->width  = clip_image_f32_batch_nx(img_res_v.get(), 0);
+        load_image_size->height = clip_image_f32_batch_ny(img_res_v.get(), 0);
         clip_add_load_image_size(ctx_clip, load_image_size);
 
-        bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd);
-        int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2);
+        clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
+        bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd);
+        int pos = int(load_image_size->width/clip_get_patch_size(ctx_clip)/2);
         *n_img_pos = (pos * pos + 2);
         if (!encoded){
             LOG_ERR("Unable to encode image \n");
@@ -328,8 +343,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
     else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
         // flat / default llava-1.5 type embedding
         *n_img_pos = clip_n_patches(ctx_clip);
-        bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096
-        delete[] img_res_v.data;
+        clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
+        bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096
         if (!encoded) {
             LOG_ERR("Unable to encode image\n");
 
@@ -340,17 +355,18 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
         // spatial_unpad llava-1.6 type embedding
         // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working
         std::vector<float *> image_embd_v;
-        image_embd_v.resize(img_res_v.size);
-        for (size_t i = 0; i < img_res_v.size; i++) {
+        image_embd_v.resize(n_imgs);
+        for (size_t i = 0; i < n_imgs; i++) {
+            clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
             image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184
-            const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
+            const bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
             if (!encoded) {
-                LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
+                LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs);
                 return false;
             }
         }
         const int64_t t_img_enc_batch_us = ggml_time_us();
-        LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
+        LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
 
         const int32_t * image_grid = clip_image_grid(ctx_clip);
         const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip);
@@ -360,12 +376,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
             grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
         }
 
-        // free all img_res_v - not needed anymore
-        delete[] img_res_v.data;
-        img_res_v.size = 0;
-        img_res_v.data = nullptr;
-
-        const int32_t image_size = clip_image_size(ctx_clip);
+        const int32_t image_size = clip_get_image_size(ctx_clip);
 
         struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size);
 
index 58503d0b22c33ec28850dbed80fd749516f744f0..114c274bc1250b5c3801c7d07a0f17628e4beddc 100644 (file)
@@ -41,14 +41,14 @@ struct mtmd_context {
 };
 
 struct mtmd_image_tokens_data {
-    clip_image_f32_batch_ptr batch_f32; // preprocessed image patches
+    clip_image_f32_batch batch_f32; // preprocessed image patches
 };
 
 struct mtmd_image_tokens {
     uint32_t nx; // number of tokens in x direction
     uint32_t ny; // number of tokens in y direction
     uint32_t n_tokens() const { return nx * ny; }
-    clip_image_f32_batch_ptr batch_f32; // preprocessed image patches
+    clip_image_f32_batch batch_f32; // preprocessed image patches
 };
 
 mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
@@ -141,8 +141,8 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
             std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3);
 
             // preprocess image
-            clip_image_f32_batch_ptr batch_f32(new clip_image_f32_batch);
-            bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), batch_f32.get());
+            clip_image_f32_batch batch_f32;
+            bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32);
             if (!ok) {
                 LOG_ERR("Unable to preprocess image\n");
                 return nullptr;
@@ -181,7 +181,7 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
     bool ok = clip_image_batch_encode(
         ctx->ctx_clip,
         ctx->n_threads,
-        image_tokens->batch_f32.get(),
+        &image_tokens->batch_f32,
         ctx->image_embd_v.data());
     return ok ? 0 : 1;
 }