]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llava-cli : multiple images (#6969)
authorcpumaxx <redacted>
Mon, 29 Apr 2024 14:34:24 +0000 (07:34 -0700)
committerGitHub <redacted>
Mon, 29 Apr 2024 14:34:24 +0000 (17:34 +0300)
Co-authored-by: root <redacted>
common/common.cpp
common/common.h
examples/llava/llava-cli.cpp

index aa494291dd52bea1c60e26afd047da2c76636801..fe84039f76e551c2aff8d7f01c6528e140c5cd8d 100644 (file)
@@ -893,7 +893,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             invalid_param = true;
             return true;
         }
-        params.image = argv[i];
+        params.image.emplace_back(argv[i]);
         return true;
     }
     if (arg == "-i" || arg == "--interactive") {
@@ -1495,7 +1495,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -ps N, --p-split N    speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
-    printf("  --image IMAGE_FILE    path to an image file. use with multimodal models\n");
+    printf("  --image IMAGE_FILE    path to an image file. use with multimodal models. Specify multiple times for batching\n");
     if (llama_supports_mlock()) {
         printf("  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
     }
index eea63a1142a4dd15d66c370f54c2bc00e6e136e9..3233d90e69eb5fc2abe79caea15613e1fc62b748 100644 (file)
@@ -167,8 +167,8 @@ struct gpt_params {
     std::string cache_type_v = "f16"; // KV cache data type for the V
 
     // multimodal models (see examples/llava)
-    std::string mmproj = ""; // path to multimodal projector
-    std::string image  = ""; // path to an image file
+    std::string mmproj = "";        // path to multimodal projector
+    std::vector<std::string> image; // path to image file(s)
 };
 
 bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
index a44c6cd7632c788164a9f7ea794e155cc2f0c9c0..157a680b5ecdb04d3ed5868b28fd52db829bc36b 100644 (file)
@@ -113,11 +113,11 @@ struct llava_context {
 };
 
 static void show_additional_info(int /*argc*/, char ** argv) {
-    LOG_TEE("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
+    LOG_TEE("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
     LOG_TEE("  note: a lower temperature value like 0.1 is recommended for better quality.\n");
 }
 
-static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) {
+static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params, const std::string & fname) {
 
     // load and preprocess the image
     llava_image_embed * embed = NULL;
@@ -133,9 +133,9 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para
         }
         params->prompt = remove_image_from_prompt(prompt);
     } else {
-        embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str());
+        embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, fname.c_str());
         if (!embed) {
-            LOG_TEE("%s: is %s really an image file?\n", __func__, params->image.c_str());
+            fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
             return NULL;
         }
     }
@@ -207,17 +207,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     printf("\n");
 }
 
-
-static struct llava_context * llava_init(gpt_params * params) {
-    const char * clip_path = params->mmproj.c_str();
-
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
-    }
-
-    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
-
+static struct llama_model * llava_init(gpt_params * params) {
     llama_backend_init();
     llama_numa_init(params->numa);
 
@@ -228,6 +218,19 @@ static struct llava_context * llava_init(gpt_params * params) {
         LOG_TEE("%s: error: unable to load model\n" , __func__);
         return NULL;
     }
+    return model;
+}
+
+static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) {
+    const char * clip_path = params->mmproj.c_str();
+
+    auto prompt = params->prompt;
+    if (prompt.empty()) {
+        prompt = "describe the image in detail.";
+    }
+
+    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
+
 
     llama_context_params ctx_params = llama_context_params_from_gpt_params(*params);
     ctx_params.n_ctx           = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
@@ -286,24 +289,30 @@ int main(int argc, char ** argv) {
         show_additional_info(argc, argv);
         return 1;
     }
-
-    auto ctx_llava = llava_init(&params);
-    if (ctx_llava == NULL) {
-        LOG_TEE("%s: error: failed to init llava\n", __func__);
+    auto model = llava_init(&params);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
         return 1;
     }
 
-    auto image_embed = load_image(ctx_llava, &params);
-    if (!image_embed) {
-        return 1;
-    }
+    for (auto & image : params.image) {
+        auto ctx_llava = llava_init_context(&params, model);
 
-    // process the prompt
-    process_prompt(ctx_llava, image_embed, &params, params.prompt);
+        auto image_embed = load_image(ctx_llava, &params, image);
+        if (!image_embed) {
+            std::cerr << "error: failed to load image " << image << ". Terminating\n\n";
+            return 1;
+        }
+
+        // process the prompt
+        process_prompt(ctx_llava, image_embed, &params, params.prompt);
 
-    llama_print_timings(ctx_llava->ctx_llama);
+        llama_print_timings(ctx_llava->ctx_llama);
+        llava_image_embed_free(image_embed);
+        ctx_llava->model = NULL;
+        llava_free(ctx_llava);
+    }
+    llama_free_model(model);
 
-    llava_image_embed_free(image_embed);
-    llava_free(ctx_llava);
     return 0;
 }