]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llava : support Minicpm-omni (#11289)
authortc-mb <redacted>
Wed, 22 Jan 2025 07:35:48 +0000 (15:35 +0800)
committerGitHub <redacted>
Wed, 22 Jan 2025 07:35:48 +0000 (09:35 +0200)
* init

* add readme

* update readme

* no use make

* update readme

* update fix code

* fix editorconfig-checker

* no change convert py

* use clip_image_u8_free

examples/llava/README-minicpmo2.6.md [new file with mode: 0644]
examples/llava/clip.cpp
examples/llava/llava.cpp
examples/llava/minicpmv-cli.cpp
examples/llava/minicpmv-convert-image-encoder-to-gguf.py
examples/llava/minicpmv-surgery.py

diff --git a/examples/llava/README-minicpmo2.6.md b/examples/llava/README-minicpmo2.6.md
new file mode 100644 (file)
index 0000000..8713a43
--- /dev/null
@@ -0,0 +1,46 @@
+## MiniCPM-o 2.6
+Currently, this readme only supports minicpm-omni's image capabilities, and we will update the full-mode support as soon as possible.
+
+### Prepare models and code
+
+Download [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from huggingface to "MiniCPM-o-2_6" folder.
+
+Clone llama.cpp:
+```bash
+git clone git@github.com:OpenBMB/llama.cpp.git
+cd llama.cpp
+git checkout minicpm-omni
+```
+
+### Usage of MiniCPM-o 2.6
+
+Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) by us)
+
+```bash
+python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6
+python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4
+python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model
+
+# quantize int4 version
+./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
+```
+
+Build llama.cpp using `CMake`:
+https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md
+
+```bash
+cmake -B build
+cmake --build build --config Release
+```
+
+Inference on Linux or Mac
+```
+# run f16 version
+./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
+
+# run quantized int4 version
+./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg  -p "What is in the image?"
+
+# or run in interactive mode
+./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
+```
index 7a8a3156bfdeffda5172b4859a47445cec2e4e1c..24073c5a9b15fd1c2748f7b7ada3baf665f5ef44 100644 (file)
@@ -718,6 +718,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         else if (ctx->minicpmv_version == 3) {
             pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
         }
+        else if (ctx->minicpmv_version == 4) {
+            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
+        }
         ggml_set_name(pos_embed, "pos_embed");
         ggml_set_input(pos_embed);
     }
@@ -1053,6 +1056,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                     n_head = hidden_size/d_head;
                     num_query = 64;
                 }
+                else if (ctx->minicpmv_version == 4) {
+                    hidden_size = 3584;
+                    n_head = hidden_size/d_head;
+                    num_query = 64;
+                }
 
                 struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
                 Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
@@ -2041,6 +2049,7 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
                 images[images.size()-1].push_back(patch);
             }
         }
+        clip_image_u8_free(refine_image);
     }
     return images;
 }
@@ -2079,6 +2088,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
                 clip_image_f32_free(res);
             }
         }
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            for (size_t j = 0; j < imgs[i].size(); ++j) {
+                if (imgs[i][j] != nullptr) {
+                    clip_image_u8_free(imgs[i][j]);
+                }
+            }
+        }
         return true;
     }
     else if (ctx->has_qwen2vl_merger) {
@@ -2335,6 +2351,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
         else if (ctx->minicpmv_version == 3) {
             n_patches = 64;
         }
+        else if (ctx->minicpmv_version == 4) {
+            n_patches = 64;
+        }
     } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
         int patch_size = params.patch_size * 2;
         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
@@ -2514,8 +2533,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
             int* positions_data = (int*)malloc(ggml_nbytes(positions));
-            int bucket_coords_h[70];
-            int bucket_coords_w[70];
+            int bucket_coords_h[1024];
+            int bucket_coords_w[1024];
             for (int i = 0; i < pos_h; i++){
                 bucket_coords_h[i] = std::floor(70.0*i/pos_h);
             }
@@ -2543,6 +2562,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             else if (ctx->minicpmv_version == 3) {
                 embed_dim = 3584;
             }
+            else if (ctx->minicpmv_version == 4) {
+                embed_dim = 3584;
+            }
             auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
 
             float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
@@ -2786,6 +2808,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         else if (ctx->minicpmv_version == 3) {
             return 3584;
         }
+        else if (ctx->minicpmv_version == 4) {
+            return 3584;
+        }
     }
     if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
         return ctx->vision_model.mm_1_b->ne[0];
index c598caf3dd1ebea024968457f431a1464db12667..2cac7933d2f2ae054f8648f3599ded4c2b240486 100644 (file)
@@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
     return true;
 }
 
-static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) {
+static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) {
     int width = image->nx;
     int height = image->ny;
     int num_patches = (height / patch_size) * (width / patch_size);
@@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
                 encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
             }
             else {
-                int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
-                if (has_minicpmv_projector == 2) {
-                    encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
-                }
-                else if (has_minicpmv_projector == 3) {
-                    encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
-                }
+                encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
             }
 
             if (!encoded) {
@@ -313,6 +307,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
         load_image_size->height = img->ny;
         clip_add_load_image_size(ctx_clip, load_image_size);
         LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height);
+        delete[] img_res_v.data;
+        img_res_v.size = 0;
+        img_res_v.data = nullptr;
     }
     else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
         // flat / default llava-1.5 type embedding
index 38c44e130e15c01a29a8d09b4a76b26e2f4019f8..53d902d616e856b331cb0112c25b636f66e51b1d 100644 (file)
@@ -140,6 +140,9 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
     else if (has_minicpmv_projector == 3) {
         system_prompt = "<|im_start|>user\n";
     }
+    else if (has_minicpmv_projector == 4) {
+        system_prompt = "<|im_start|>user\n";
+    }
     LOG_INF("%s: image token past: %d\n", __func__, n_past);
     eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
     process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
@@ -227,6 +230,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
         else if (has_minicpmv_projector == 3) {
             user_prompt = "<|im_start|>user\n" + prompt;
         }
+        else if (has_minicpmv_projector == 4) {
+            user_prompt = "<|im_start|>user\n" + prompt;
+        }
     }
 
     eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
@@ -236,6 +242,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
     else if (has_minicpmv_projector == 3) {
         eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
     }
+    else if (has_minicpmv_projector == 4) {
+        eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
+    }
 
     // generate the response
 
@@ -308,7 +317,6 @@ int main(int argc, char ** argv) {
                     const auto * tmp = llama_loop(ctx_llava, smpl, n_past);
                     response += tmp;
                     if (strcmp(tmp, "</s>") == 0) break;
-                    if (strstr(tmp, "###")) break; // Yi-VL behavior
                     printf("%s", tmp);// mistral llava-1.6
                     if (strstr(response.c_str(), "<user>")) break; // minicpm-v
                     fflush(stdout);
index ea773742a832bb2c4549754f4e3a78b969313de5..9b196757f07c947310926c7a6d74356bd4d0c814 100644 (file)
@@ -501,7 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073]
 default_image_std = [0.26862954, 0.26130258, 0.27577711]
 ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
 ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
-ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2)
+ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2)
 
 # with proper
 args = ap.parse_args()
@@ -545,12 +545,19 @@ if args.use_f32:
 
 minicpmv_version = args.minicpmv_version
 emb_dim = 4096
+block_count = 26
 if minicpmv_version == 1:
     emb_dim = 2304
+    block_count = 26
 elif minicpmv_version == 2:
     emb_dim = 4096
+    block_count = 27
 elif minicpmv_version == 3:
     emb_dim = 3584
+    block_count = 27
+elif minicpmv_version == 4:
+    emb_dim = 3584
+    block_count = 27
 
 default_vision_config = {
         "hidden_size": 1152,
@@ -567,6 +574,9 @@ model = Idefics2VisionTransformer(vision_config)
 if minicpmv_version == 3:
     vision_config = SiglipVisionConfig(**default_vision_config)
     model = SiglipVisionTransformer(vision_config)
+elif minicpmv_version == 4:
+    vision_config = SiglipVisionConfig(**default_vision_config)
+    model = SiglipVisionTransformer(vision_config)
 
 processor = None
 # if model.attn_pool is not None:
@@ -587,7 +597,7 @@ elif args.minicpmv_projector is not None:
     fname_middle = "mmproj-"
     has_text_encoder = False
     has_minicpmv_projector = True
-    minicpmv_version = 3
+    minicpmv_version = 4
 elif args.vision_only:
     fname_middle = "vision-"
     has_text_encoder = False
@@ -625,7 +635,6 @@ if has_vision_encoder:
     fout.add_uint32("clip.vision.projection_dim", 0)
     fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
     fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
-    block_count = 26
     fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count)
 
     if processor is not None:
index 748ff5c57824e4a01e69335aee8c1f239dd38949..ba82116582b1fa4d530ee61ed6e98ca9471a6ce3 100644 (file)
@@ -8,7 +8,7 @@ ap.add_argument("-m", "--model", help="Path to MiniCPM-V model")
 args = ap.parse_args()
 
 # find the model part that includes the the multimodal projector weights
-model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
+model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16)
 checkpoint = model.state_dict()
 
 # get a list of mm tensor names