]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
clip : enable gpu backend (#4205)
authorSteward Garcia <redacted>
Fri, 29 Dec 2023 16:52:15 +0000 (11:52 -0500)
committerGitHub <redacted>
Fri, 29 Dec 2023 16:52:15 +0000 (18:52 +0200)
* clip: enable CUDA backend

* add missing kernels

* add enough padding for alignment

* remove ggml_repeat of clip.cpp

* add metal backend

* llava : fixes

- avoid ggml_repeat
- use GGML_USE_ instead of CLIP_USE_ macros
- remove unused vars

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/llava/CMakeLists.txt
examples/llava/clip.cpp

index 48dae1506e81e752ee0a8799b82e3001037aaef5..2985caff8379a1684eff84b40452985dc6dac7b3 100644 (file)
@@ -24,7 +24,8 @@ endif()
 
 if (NOT MSVC)
     target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h
-    endif()
+endif()
+
 if(TARGET BUILD_INFO)
     add_dependencies(llava BUILD_INFO)
 endif()
index f06ec400d9a6e5bcffb20752df3e518f0fc26729..f9326a5cc3dff4c321f7f2b88e6ca46a66fe800d 100644 (file)
 #include "clip.h"
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
 
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb_image.h"
 
-#define CLIP_DEBUG
-
 static std::string format(const char * fmt, ...) {
     va_list ap;
     va_list ap2;
@@ -196,20 +203,6 @@ struct clip_vision_model {
     struct ggml_tensor * mm_2_b;
 };
 
-// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
-struct clip_buffer {
-    uint8_t * data = NULL;
-    size_t size = 0;
-
-    void resize(size_t size) {
-        delete[] data;
-        data = new uint8_t[size];
-        this->size = size;
-    }
-
-    ~clip_buffer() { delete[] data; }
-};
-
 struct clip_ctx {
     bool has_text_encoder = false;
     bool has_vision_encoder = false;
@@ -223,9 +216,10 @@ struct clip_ctx {
     struct gguf_context * ctx_gguf;
 
     // memory buffers to evaluate the model
-    clip_buffer buf_compute;
-    clip_buffer buf_alloc;
-    ggml_allocr * alloc = NULL;
+    ggml_backend_buffer_t params_buffer = NULL;
+    ggml_backend_buffer_t compute_buffer = NULL;
+    ggml_backend_t backend = NULL;
+    ggml_allocr * compute_alloc = NULL;
 };
 
 static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_image_f32_batch * imgs) {
@@ -252,25 +246,20 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
     if(ctx->has_llava_projector) {
         GGML_ASSERT(batch_size == 1);
     }
-
-    const auto & buf_compute = ctx->buf_compute;
-
     struct ggml_init_params params = {
-        /*.mem_size =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.data,
-        /*.no_alloc =*/ false,
+        /*.mem_size =*/ GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc =*/ true,
     };
 
-    params.no_alloc = true;
-
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
-    ggml_allocr_alloc(ctx->alloc, inp_raw);
+    ggml_allocr_alloc(ctx->compute_alloc, inp_raw);
 
-    if (!ggml_allocr_is_measure(ctx->alloc)) {
-        float * data = (float *)ggml_get_data(inp_raw);
+    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
+        float * data = (float *)malloc(ggml_nbytes(inp_raw));
 
         for (size_t i = 0; i < imgs->size; i++) {
             const int nx = imgs->data[i].nx;
@@ -289,6 +278,8 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
                 }
             }
         }
+        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+        free(data);
     }
 
     struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
@@ -298,36 +289,39 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
 
     // concat class_embeddings and patch_embeddings
     struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
-    ggml_allocr_alloc(ctx->alloc, embeddings);
-    if (!ggml_allocr_is_measure(ctx->alloc)) {
-        ggml_set_zero(embeddings);
+    ggml_allocr_alloc(ctx->compute_alloc, embeddings);
+    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
+        void* zero_mem = malloc(ggml_nbytes(embeddings));
+        memset(zero_mem, 0, ggml_nbytes(embeddings));
+        ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
+        free(zero_mem);
     }
 
-    struct ggml_tensor * temp = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size);
-    ggml_allocr_alloc(ctx->alloc, temp);
+    embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+            embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
 
-    embeddings = ggml_acc(ctx0, embeddings, ggml_repeat(ctx0, model.class_embedding, temp), embeddings->nb[1],
-                          embeddings->nb[2], embeddings->nb[3], 0);
-    embeddings =
-        ggml_acc(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+    embeddings = ggml_acc(ctx0, embeddings, inp,
+            embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
 
     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
-    ggml_allocr_alloc(ctx->alloc, positions);
-    if (!ggml_allocr_is_measure(ctx->alloc)) {
+    ggml_allocr_alloc(ctx->compute_alloc, positions);
+    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
+        int* positions_data = (int*)malloc(ggml_nbytes(positions));
         for (int i = 0; i < num_positions; i++) {
-            ggml_set_i32_1d(positions, i, i);
+            positions_data[i] = i;
         }
+        ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+        free(positions_data);
     }
 
     embeddings =
-        ggml_add(ctx0, embeddings, ggml_repeat(ctx0, ggml_get_rows(ctx0, model.position_embeddings, positions), embeddings));
+        ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
 
     // pre-layernorm
     {
         embeddings = ggml_norm(ctx0, embeddings, eps);
 
-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.pre_ln_w, embeddings), embeddings),
-                              ggml_repeat(ctx0, model.pre_ln_b, embeddings));
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
     }
 
     // loop over layers
@@ -340,15 +334,15 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
         {
             cur = ggml_norm(ctx0, cur, eps);
 
-            cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_1_w, cur), cur),
-                           ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
+                           model.layers[il].ln_1_b);
         }
 
         // self-attention
         {
 
             struct ggml_tensor * Q =
-                ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur));
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
 
             Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
             Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
@@ -356,14 +350,14 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
             Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
 
             struct ggml_tensor * K =
-                ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, cur), ggml_mul_mat(ctx0, model.layers[il].k_w, cur));
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
 
             K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
             K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
             K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
 
             struct ggml_tensor * V =
-                ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, cur), ggml_mul_mat(ctx0, model.layers[il].v_w, cur));
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
 
             V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
             V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
@@ -379,7 +373,7 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
         }
 
         // attention output
-        cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].o_b, cur), ggml_mul_mat(ctx0, model.layers[il].o_w, cur));
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
 
         // re-add the layer input, e.g., residual
         cur = ggml_add(ctx0, cur, embeddings);
@@ -390,12 +384,11 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
         {
             cur = ggml_norm(ctx0, cur, eps);
 
-            cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_2_w, cur), cur),
-                           ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
         }
 
         cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
-        cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].ff_i_b, cur), cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
 
         if (ctx->use_gelu) {
             cur = ggml_gelu_inplace(ctx0, cur);
@@ -404,7 +397,7 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
         }
 
         cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
-        cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].ff_o_b, cur), cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
 
         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);
@@ -417,23 +410,26 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
         embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
 
         struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
-        ggml_allocr_alloc(ctx->alloc, patches);
-        if (!ggml_allocr_is_measure(ctx->alloc)) {
-            for (int i = 0; i < num_patches; ++i) {
-                ggml_set_i32_1d(patches, i, i+1);
+        ggml_allocr_alloc(ctx->compute_alloc, patches);
+        if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
+            int* patches_data = (int*)malloc(ggml_nbytes(patches));
+            for (int i = 0; i < num_positions; i++) {
+                patches_data[i] = i + 1;
             }
+            ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+            free(patches_data);
         }
 
         embeddings = ggml_get_rows(ctx0, embeddings, patches);
 
         // mm projection 0
         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-        embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_0_b, embeddings), embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
 
         embeddings = ggml_gelu(ctx0, embeddings);
 
         embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
-        embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_2_b, embeddings), embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
     }
 
     // build the graph
@@ -446,7 +442,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
 
 // read and create ggml_context containing the tensors and their data
 struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
-
     struct ggml_context * meta = NULL;
 
     struct gguf_init_params params = {
@@ -479,7 +474,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         printf("%s: ftype:        %s\n", __func__, ftype_str.c_str());
         printf("\n");
     }
-
+    const int n_tensors = gguf_get_n_tensors(ctx);
     // kv
     if (verbosity >= 3) {
         const int n_kv = gguf_get_n_kv(ctx);
@@ -493,27 +488,38 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     }
 
     // data
-    size_t ctx_size = 0;
+    size_t buffer_size = 0;
     {
-        const int n_tensors = gguf_get_n_tensors(ctx);
-
         for (int i = 0; i < n_tensors; ++i) {
             const char * name = gguf_get_tensor_name(ctx, i);
             const size_t offset = gguf_get_tensor_offset(ctx, i);
-
             struct ggml_tensor * cur = ggml_get_tensor(meta, name);
-            ctx_size += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE;
             size_t tensor_size = ggml_nbytes(cur);
-            size_t padded_size = ggml_nbytes_pad(cur);
-            ctx_size += padded_size;
+            buffer_size += tensor_size;
             if (verbosity >= 3) {
-                printf("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, padded_size=%zu, offset=%zu\n", __func__, i,
-                       ggml_n_dims(cur), cur->name, tensor_size, padded_size, offset);
+                printf("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu\n", __func__, i,
+                       ggml_n_dims(cur), cur->name, tensor_size, offset);
             }
         }
     }
 
+    buffer_size += n_tensors * 128 /* CLIP PADDING */;
+
     clip_ctx * new_clip = new clip_ctx;
+#ifdef GGML_USE_CUBLAS
+    new_clip->backend = ggml_backend_cuda_init(0);
+    printf("%s: CLIP using CUDA backend\n", __func__);
+#endif
+
+#ifdef GGML_USE_METAL
+    new_clip->backend = ggml_backend_metal_init();
+    printf("%s: CLIP using Metal backend\n", __func__);
+#endif
+
+    if (!new_clip->backend) {
+        new_clip->backend = ggml_backend_cpu_init();
+        printf("%s: CLIP using CPU backend\n", __func__);
+    }
 
     // model size and capabilities
     {
@@ -539,17 +545,20 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             printf("%s: text_encoder:   %d\n", __func__, new_clip->has_text_encoder);
             printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
             printf("%s: llava_projector:  %d\n", __func__, new_clip->has_llava_projector);
-            printf("%s: model size:     %.2f MB\n", __func__, (ctx_size / 1024.0 / 1024.0));
+            printf("%s: model size:     %.2f MB\n", __func__, buffer_size / 1024.0 / 1024.0);
             printf("%s: metadata size:  %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
         }
     }
 
+    printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, buffer_size / (1024.0 * 1024.0), n_tensors);
+
     // load tensors
     {
+        std::vector<uint8_t> read_buf;
         struct ggml_init_params params = {
-            /*.mem_size =*/ ctx_size,
+            /*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(),
             /*.mem_buffer =*/ NULL,
-            /*.no_alloc =*/ false,
+            /*.no_alloc =*/ true,
         };
 
         new_clip->ctx = ggml_init(params);
@@ -566,13 +575,21 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             return nullptr;
         }
 
-        const int n_tensors = gguf_get_n_tensors(ctx);
+        // add tensors to context
         for (int i = 0; i < n_tensors; ++i) {
             const char * name = gguf_get_tensor_name(ctx, i);
             struct ggml_tensor * t = ggml_get_tensor(meta, name);
             struct ggml_tensor * cur = ggml_dup_tensor(new_clip->ctx, t);
             ggml_set_name(cur, name);
+        }
 
+        // alloc memory and offload data
+        new_clip->params_buffer = ggml_backend_alloc_buffer(new_clip->backend, buffer_size);
+        ggml_allocr* alloc = ggml_allocr_new_from_buffer(new_clip->params_buffer);
+        for (int i = 0; i < n_tensors; ++i) {
+            const char * name = gguf_get_tensor_name(ctx, i);
+            struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx, name);
+            ggml_allocr_alloc(alloc, cur);
             const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
             fin.seekg(offset, std::ios::beg);
             if (!fin) {
@@ -580,10 +597,22 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
                 clip_free(new_clip);
                 return nullptr;
             }
-
-            fin.read(reinterpret_cast<char *>(cur->data), ggml_nbytes(t));
+            int num_bytes = ggml_nbytes(cur);
+            if (ggml_backend_is_cpu(new_clip->backend)
+#ifdef GGML_USE_METAL
+            || ggml_backend_is_metal(new_clip->backend)
+#endif
+            ) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(num_bytes);
+                fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
+                ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+            }
         }
-
+        ggml_allocr_free(alloc);
         fin.close();
     }
 
@@ -657,18 +686,16 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
 
 // measure mem requirement and allocate
     {
-        static const size_t tensor_alignment = 32;
-        new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
-        new_clip->alloc = ggml_allocr_new_measure(tensor_alignment);
+        new_clip->compute_alloc = ggml_allocr_new_measure_from_backend(new_clip->backend);
         clip_image_f32_batch batch;
         batch.size = 1;
         ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch);
-        size_t alloc_size = ggml_allocr_alloc_graph(new_clip->alloc, gf) + tensor_alignment;
-        ggml_allocr_free(new_clip->alloc);
-        new_clip->buf_alloc.resize(alloc_size);
-        new_clip->alloc = ggml_allocr_new(new_clip->buf_alloc.data, new_clip->buf_alloc.size, tensor_alignment);
+        size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(new_clip->compute_alloc, gf);
+        ggml_allocr_free(new_clip->compute_alloc);
+        new_clip->compute_buffer = ggml_backend_alloc_buffer(new_clip->backend, compute_memory_buffer_size);
+        new_clip->compute_alloc = ggml_allocr_new_from_buffer(new_clip->compute_buffer);
 
-        printf("%s: total allocated memory: %.2f MB\n", __func__, (new_clip->buf_compute.size + alloc_size)/1024.0/1024.0);
+        printf("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
     }
 
     return new_clip;
@@ -852,29 +879,29 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
     }
 
     // reset alloc buffer to clean the memory from previous invocations
-    ggml_allocr_reset(ctx->alloc);
+    ggml_allocr_reset(ctx->compute_alloc);
 
     // build the inference graph
     ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
-    ggml_allocr_alloc_graph(ctx->alloc, gf);
+    ggml_allocr_alloc_graph(ctx->compute_alloc, gf);
+
+    if (ggml_backend_is_cpu(ctx->backend)) {
+        ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
+    }
 
-    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
-    if (plan.work_size > 0) {
-        plan.work_data = (uint8_t *)malloc(plan.work_size);
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(ctx->backend)) {
+        ggml_backend_metal_set_n_cb(ctx->backend, n_threads);
     }
+#endif
 
-    ggml_graph_compute(gf, &plan);
+    ggml_backend_graph_compute(ctx->backend, gf);
 
     // the last node is the embedding tensor
-struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
 
     // copy the embeddings to the location passed by the user
-    memcpy(vec, ggml_get_data_f32(embeddings), ggml_nbytes(embeddings));
-
-    if (plan.work_size > 0) {
-        free(plan.work_data);
-    }
-
+    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
     return true;
 }
 
@@ -1045,8 +1072,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
     gguf_free(ctx_out);
 
     {
-        printf("%s: original size  = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0);
-        printf("%s: quantized size  = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0);
+        printf("%s: original  size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0);
+        printf("%s: quantized size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0);
 
         int64_t sum_all = 0;
         for (size_t i = 0; i < hist_all.size(); ++i) {