]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llava-cli: fix base64 prompt (#7248)
authork.h.lai <redacted>
Mon, 13 May 2024 14:02:36 +0000 (22:02 +0800)
committerGitHub <redacted>
Mon, 13 May 2024 14:02:36 +0000 (00:02 +1000)
examples/llava/llava-cli.cpp

index da60ddf2f057dba3a5de2679707faf7b9e1cd840..a6d67e5d72cd2882e8aa30efaa6b6211973675a4 100644 (file)
@@ -300,14 +300,10 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    for (auto & image : params.image) {
+    if (prompt_contains_image(params.prompt)) {
         auto ctx_llava = llava_init_context(&params, model);
 
-        auto image_embed = load_image(ctx_llava, &params, image);
-        if (!image_embed) {
-            std::cerr << "error: failed to load image " << image << ". Terminating\n\n";
-            return 1;
-        }
+        auto image_embed = load_image(ctx_llava, &params, "");
 
         // process the prompt
         process_prompt(ctx_llava, image_embed, &params, params.prompt);
@@ -316,7 +312,26 @@ int main(int argc, char ** argv) {
         llava_image_embed_free(image_embed);
         ctx_llava->model = NULL;
         llava_free(ctx_llava);
+    } else {
+        for (auto & image : params.image) {
+            auto ctx_llava = llava_init_context(&params, model);
+
+            auto image_embed = load_image(ctx_llava, &params, image);
+            if (!image_embed) {
+                std::cerr << "error: failed to load image " << image << ". Terminating\n\n";
+                return 1;
+            }
+
+            // process the prompt
+            process_prompt(ctx_llava, image_embed, &params, params.prompt);
+
+            llama_print_timings(ctx_llava->ctx_llama);
+            llava_image_embed_free(image_embed);
+            ctx_llava->model = NULL;
+            llava_free(ctx_llava);
+        }
     }
+
     llama_free_model(model);
 
     return 0;