]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
embedding : more cli arguments (#7458)
authorYann Follet <redacted>
Mon, 24 Jun 2024 05:30:24 +0000 (13:30 +0800)
committerGitHub <redacted>
Mon, 24 Jun 2024 05:30:24 +0000 (08:30 +0300)
* add parameters for embeddings
--embd-normalize
--embd-output-format
--embd-separator
description in the README.md

* Update README.md

fix tipo

* Trailing whitespace

* fix json generation, use " not '

* fix merge master

* fix code formating
group of parameters // embedding
print usage for embedding parameters

---------

Co-authored-by: Brian <redacted>
common/common.cpp
common/common.h
examples/embedding/README.md
examples/embedding/embedding.cpp

index cfdedcbae0cd99a921e928914dff5b064cf30d78..1dc53265134a7acae55874bbc7373ffd9cd60a0e 100644 (file)
@@ -273,26 +273,22 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     return true;
 }
 
+#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
+
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     const char split_delim = ',';
 
     llama_sampling_params & sparams = params.sparams;
 
     if (arg == "-s" || arg == "--seed") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
         params.seed = std::stoul(argv[i]);
         sparams.seed = std::stoul(argv[i]);
         return true;
     }
     if (arg == "-t" || arg == "--threads") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_threads = std::stoi(argv[i]);
         if (params.n_threads <= 0) {
             params.n_threads = std::thread::hardware_concurrency();
@@ -300,10 +296,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-tb" || arg == "--threads-batch") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_threads_batch = std::stoi(argv[i]);
         if (params.n_threads_batch <= 0) {
             params.n_threads_batch = std::thread::hardware_concurrency();
@@ -311,10 +304,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-td" || arg == "--threads-draft") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_threads_draft = std::stoi(argv[i]);
         if (params.n_threads_draft <= 0) {
             params.n_threads_draft = std::thread::hardware_concurrency();
@@ -322,10 +312,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-tbd" || arg == "--threads-batch-draft") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_threads_batch_draft = std::stoi(argv[i]);
         if (params.n_threads_batch_draft <= 0) {
             params.n_threads_batch_draft = std::thread::hardware_concurrency();
@@ -333,10 +320,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-p" || arg == "--prompt") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.prompt = argv[i];
         return true;
     }
@@ -349,10 +333,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--prompt-cache") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.path_prompt_cache = argv[i];
         return true;
     }
@@ -365,10 +346,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-bf" || arg == "--binary-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::ifstream file(argv[i], std::ios::binary);
         if (!file) {
             fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -384,10 +362,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-f" || arg == "--file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::ifstream file(argv[i]);
         if (!file) {
             fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -403,10 +378,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--in-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::ifstream file(argv[i]);
         if (!file) {
             fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -417,66 +389,42 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-n" || arg == "--predict" || arg == "--n-predict") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_predict = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--top-k") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.top_k = std::stoi(argv[i]);
         return true;
     }
     if (arg == "-c" || arg == "--ctx-size") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_ctx = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--grp-attn-n" || arg == "-gan") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.grp_attn_n = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--grp-attn-w" || arg == "-gaw") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.grp_attn_w = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--rope-freq-base") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.rope_freq_base = std::stof(argv[i]);
         return true;
     }
     if (arg == "--rope-freq-scale") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.rope_freq_scale = std::stof(argv[i]);
         return true;
     }
     if (arg == "--rope-scaling") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::string value(argv[i]);
         /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
         else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
@@ -485,58 +433,37 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--rope-scale") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.rope_freq_scale = 1.0f / std::stof(argv[i]);
         return true;
     }
     if (arg == "--yarn-orig-ctx") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.yarn_orig_ctx = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--yarn-ext-factor") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.yarn_ext_factor = std::stof(argv[i]);
         return true;
     }
     if (arg == "--yarn-attn-factor") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.yarn_attn_factor = std::stof(argv[i]);
         return true;
     }
     if (arg == "--yarn-beta-fast") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.yarn_beta_fast = std::stof(argv[i]);
         return true;
     }
     if (arg == "--yarn-beta-slow") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.yarn_beta_slow = std::stof(argv[i]);
         return true;
     }
     if (arg == "--pooling") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::string value(argv[i]);
         /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
         else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
@@ -546,157 +473,100 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--defrag-thold" || arg == "-dt") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.defrag_thold = std::stof(argv[i]);
         return true;
     }
     if (arg == "--samplers") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         const auto sampler_names = string_split(argv[i], ';');
         sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true);
         return true;
     }
     if (arg == "--sampling-seq") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]);
         return true;
     }
     if (arg == "--top-p") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.top_p = std::stof(argv[i]);
         return true;
     }
     if (arg == "--min-p") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.min_p = std::stof(argv[i]);
         return true;
     }
     if (arg == "--temp") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.temp = std::stof(argv[i]);
         sparams.temp = std::max(sparams.temp, 0.0f);
         return true;
     }
     if (arg == "--tfs") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.tfs_z = std::stof(argv[i]);
         return true;
     }
     if (arg == "--typical") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.typical_p = std::stof(argv[i]);
         return true;
     }
     if (arg == "--repeat-last-n") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.penalty_last_n = std::stoi(argv[i]);
         sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
         return true;
     }
     if (arg == "--repeat-penalty") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.penalty_repeat = std::stof(argv[i]);
         return true;
     }
     if (arg == "--frequency-penalty") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.penalty_freq = std::stof(argv[i]);
         return true;
     }
     if (arg == "--presence-penalty") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.penalty_present = std::stof(argv[i]);
         return true;
     }
     if (arg == "--dynatemp-range") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.dynatemp_range = std::stof(argv[i]);
         return true;
     }
     if (arg == "--dynatemp-exp") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.dynatemp_exponent = std::stof(argv[i]);
         return true;
     }
     if (arg == "--mirostat") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.mirostat = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--mirostat-lr") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.mirostat_eta = std::stof(argv[i]);
         return true;
     }
     if (arg == "--mirostat-ent") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.mirostat_tau = std::stof(argv[i]);
         return true;
     }
     if (arg == "--cfg-negative-prompt") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.cfg_negative_prompt = argv[i];
         return true;
     }
     if (arg == "--cfg-negative-prompt-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::ifstream file(argv[i]);
         if (!file) {
             fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -710,203 +580,125 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--cfg-scale") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.cfg_scale = std::stof(argv[i]);
         return true;
     }
     if (arg == "-b" || arg == "--batch-size") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_batch = std::stoi(argv[i]);
         return true;
     }
     if (arg == "-ub" || arg == "--ubatch-size") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_ubatch = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--keep") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_keep = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--draft") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_draft = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--chunks") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_chunks = std::stoi(argv[i]);
         return true;
     }
     if (arg == "-np" || arg == "--parallel") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_parallel = std::stoi(argv[i]);
         return true;
     }
     if (arg == "-ns" || arg == "--sequences") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_sequences = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--p-split" || arg == "-ps") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.p_split = std::stof(argv[i]);
         return true;
     }
     if (arg == "-m" || arg == "--model") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.model = argv[i];
         return true;
     }
     if (arg == "-md" || arg == "--model-draft") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.model_draft = argv[i];
         return true;
     }
     if (arg == "-a" || arg == "--alias") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.model_alias = argv[i];
         return true;
     }
     if (arg == "-mu" || arg == "--model-url") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.model_url = argv[i];
         return true;
     }
     if (arg == "-hfr" || arg == "--hf-repo") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.hf_repo = argv[i];
         return true;
     }
     if (arg == "-hff" || arg == "--hf-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.hf_file = argv[i];
         return true;
     }
     if (arg == "--lora") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.lora_adapter.emplace_back(argv[i], 1.0f);
         params.use_mmap = false;
         return true;
     }
     if (arg == "--lora-scaled") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         const char* lora_adapter = argv[i];
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
         params.use_mmap = false;
         return true;
     }
     if (arg == "--lora-base") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.lora_base = argv[i];
         return true;
     }
     if (arg == "--control-vector") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.control_vectors.push_back({ 1.0f, argv[i], });
         return true;
     }
     if (arg == "--control-vector-scaled") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         const char* fname = argv[i];
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.control_vectors.push_back({ std::stof(argv[i]), fname, });
         return true;
     }
     if (arg == "--control-vector-layer-range") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.control_vector_layer_start = std::stoi(argv[i]);
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.control_vector_layer_end = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--mmproj") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.mmproj = argv[i];
         return true;
     }
     if (arg == "--image") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.image.emplace_back(argv[i]);
         return true;
     }
@@ -922,6 +714,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.embedding = true;
         return true;
     }
+    if (arg == "--embd-normalize") {
+        CHECK_ARG
+        params.embd_normalize = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--embd-output-format") {
+        CHECK_ARG
+        params.embd_out = argv[i];
+        return true;
+    }
+    if (arg == "--embd-separator") {
+        CHECK_ARG
+        params.embd_sep = argv[i];
+        return true;
+    }
     if (arg == "-if" || arg == "--interactive-first") {
         params.interactive_first = true;
         return true;
@@ -975,10 +782,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_gpu_layers = std::stoi(argv[i]);
         if (!llama_supports_gpu_offload()) {
             fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers option will be ignored\n");
@@ -987,10 +791,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--gpu-layers-draft") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_gpu_layers_draft = std::stoi(argv[i]);
         if (!llama_supports_gpu_offload()) {
             fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n");
@@ -999,10 +800,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--main-gpu" || arg == "-mg") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.main_gpu = std::stoi(argv[i]);
 #ifndef GGML_USE_CUDA_SYCL_VULKAN
         fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the main GPU has no effect.\n");
@@ -1010,10 +808,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--split-mode" || arg == "-sm") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::string arg_next = argv[i];
         if (arg_next == "none") {
             params.split_mode = LLAMA_SPLIT_MODE_NONE;
@@ -1038,10 +833,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--tensor-split" || arg == "-ts") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::string arg_next = argv[i];
 
         // split string by , and /
@@ -1066,10 +858,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--rpc") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.rpc_servers = argv[i];
         return true;
     }
@@ -1078,10 +867,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--numa") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::string value(argv[i]);
         /**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
         else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
@@ -1094,10 +880,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--verbosity") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.verbosity = std::stoi(argv[i]);
         return true;
     }
@@ -1110,18 +893,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-r" || arg == "--reverse-prompt") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.antiprompt.emplace_back(argv[i]);
         return true;
     }
     if (arg == "-ld" || arg == "--logdir") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.logdir = argv[i];
 
         if (params.logdir.back() != DIRECTORY_SEPARATOR) {
@@ -1130,26 +907,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-lcs" || arg == "--lookup-cache-static") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.lookup_cache_static = argv[i];
         return true;
     }
     if (arg == "-lcd" || arg == "--lookup-cache-dynamic") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.lookup_cache_dynamic = argv[i];
         return true;
     }
     if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.logits_file = argv[i];
         return true;
     }
@@ -1158,26 +926,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--ppl-stride") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.ppl_stride = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--ppl-output-type") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.ppl_output_type = std::stoi(argv[i]);
         return true;
     }
     if (arg == "-ptc" || arg == "--print-token-count") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_print = std::stoi(argv[i]);
         return true;
     }
@@ -1190,10 +949,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--hellaswag-tasks") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.hellaswag_tasks = std::stoi(argv[i]);
         return true;
     }
@@ -1202,10 +958,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--winogrande-tasks") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.winogrande_tasks = std::stoi(argv[i]);
         return true;
     }
@@ -1214,10 +967,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--multiple-choice-tasks") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.multiple_choice_tasks = std::stoi(argv[i]);
         return true;
     }
@@ -1234,10 +984,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-l" || arg == "--logit-bias") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::stringstream ss(argv[i]);
         llama_token key;
         char sign;
@@ -1270,34 +1017,22 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--in-prefix") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.input_prefix = argv[i];
         return true;
     }
     if (arg == "--in-suffix") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.input_suffix = argv[i];
         return true;
     }
     if (arg == "--grammar") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.grammar = argv[i];
         return true;
     }
     if (arg == "--grammar-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::ifstream file(argv[i]);
         if (!file) {
             fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -1312,18 +1047,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-j" || arg == "--json-schema") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
         return true;
     }
     if (arg == "--override-kv") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
             fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
             invalid_param = true;
@@ -1332,42 +1061,27 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--host") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.hostname = argv[i];
         return true;
     }
     if (arg == "--port") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.port = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--path") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.public_path = argv[i];
         return true;
     }
     if (arg == "--api-key") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.api_keys.push_back(argv[i]);
         return true;
     }
     if (arg == "--api-key-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::ifstream key_file(argv[i]);
         if (!key_file) {
             fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -1384,43 +1098,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--ssl-key-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.ssl_file_key = argv[i];
         return true;
     }
     if (arg == "--ssl-cert-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.ssl_file_cert = argv[i];
         return true;
     }
     if (arg == "--timeout" || arg == "-to") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.timeout_read  = std::stoi(argv[i]);
         params.timeout_write = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--threads-http") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_threads_http = std::stoi(argv[i]);
         return true;
     }
     if (arg == "-spf" || arg == "--system-prompt-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::ifstream file(argv[i]);
         if (!file) {
             fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -1437,10 +1136,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--log-format") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         if (std::strcmp(argv[i], "json") == 0) {
             params.log_json = true;
         } else if (std::strcmp(argv[i], "text") == 0) {
@@ -1460,10 +1156,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--slot-save-path") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.slot_save_path = argv[i];
         // if doesn't end with DIRECTORY_SEPARATOR, add it
         if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
@@ -1472,10 +1165,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--chat-template") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         if (!llama_chat_verify_template(argv[i])) {
             fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
             fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
@@ -1486,10 +1176,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--slot-prompt-similarity" || arg == "-sps") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.slot_prompt_similarity = std::stof(argv[i]);
         return true;
     }
@@ -1498,37 +1185,25 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-npp") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         auto p = string_split<int>(argv[i], split_delim);
         params.n_pp.insert(params.n_pp.end(), p.begin(), p.end());
         return true;
     }
     if (arg == "-ntg") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         auto p = string_split<int>(argv[i], split_delim);
         params.n_tg.insert(params.n_tg.end(), p.begin(), p.end());
         return true;
     }
     if (arg == "-npl") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         auto p = string_split<int>(argv[i], split_delim);
         params.n_pl.insert(params.n_pl.end(), p.begin(), p.end());
         return true;
     }
     if (arg == "--context-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         std::ifstream file(argv[i], std::ios::binary);
         if (!file) {
             fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -1539,59 +1214,38 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--chunk-size") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.chunk_size = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--chunk-separator") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.chunk_separator = argv[i];
         return true;
     }
     if (arg == "--junk") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_junk = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--pos") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.i_pos = std::stoi(argv[i]);
         return true;
     }
     if (arg == "-o" || arg == "--output" || arg == "--output-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.out_file = argv[i];
         params.cvector_outfile = argv[i];
         return true;
     }
     if (arg == "-ofreq" || arg == "--output-frequency") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_out_freq = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--save-frequency") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_save_freq = std::stoi(argv[i]);
         return true;
     }
@@ -1604,59 +1258,38 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--chunk" || arg == "--from-chunk") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.i_chunk = std::stoi(argv[i]);
         return true;
     }
     // cvector params
     if (arg == "--completions-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.cvector_completions_file = argv[i];
         return true;
     }
     if (arg == "--positive-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.cvector_positive_file = argv[i];
         return true;
     }
     if (arg == "--negative-file") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.cvector_negative_file = argv[i];
         return true;
     }
     if (arg == "--completions") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_completions = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--pca-batch") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_pca_batch = std::stoi(argv[i]);
         return true;
     }
     if (arg == "--pca-iter") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         params.n_pca_iterations = std::stoi(argv[i]);
         return true;
     }
@@ -1671,10 +1304,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         // We have a matching known parameter requiring an argument,
         //  now we need to check if there is anything after this argv
         //  and flag invalid_param or parse it.
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
+        CHECK_ARG
         if (!log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i - 1], argv[i])) {
             invalid_param = true;
             return true;
@@ -1944,6 +1574,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "bench",       "-ntg n0,n1,...",                "number of text generation tokens" });
     options.push_back({ "bench",       "-npl n0,n1,...",                "number of parallel prompts" });
 
+    options.push_back({ "embedding" });
+    options.push_back({ "embedding",   "       --embd-normalize",       "normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize });
+    options.push_back({ "embedding",   "       --embd-output-format",   "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix" });
+    options.push_back({ "embedding",   "       --embd-separator",       "separator of embendings (default \\n) for example \"<#sep#>\"" });
+
     options.push_back({ "server" });
     options.push_back({ "server",      "       --host HOST",            "ip address to listen (default: %s)", params.hostname.c_str() });
     options.push_back({ "server",      "       --port PORT",            "port to listen (default: %d)", params.port });
@@ -3052,14 +2687,34 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz
 // Embedding utils
 //
 
-void llama_embd_normalize(const float * inp, float * out, int n) {
+void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) {
     double sum = 0.0;
-    for (int i = 0; i < n; i++) {
-        sum += inp[i] * inp[i];
+
+    switch (embd_norm) {
+        case -1: // no normalisation
+            sum = 1.0;
+            break;
+        case 0: // max absolute
+            for (int i = 0; i < n; i++) {
+                if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
+            }
+            sum /= 32760.0; // make an int16 range
+            break;
+        case 2: // euclidean
+            for (int i = 0; i < n; i++) {
+                sum += inp[i] * inp[i];
+            }
+            sum = std::sqrt(sum);
+            break;
+        default: // p-norm (euclidean is p-norm p=2)
+            for (int i = 0; i < n; i++) {
+                sum += std::pow(std::abs(inp[i]), embd_norm);
+            }
+            sum = std::pow(sum, 1.0 / embd_norm);
+            break;
     }
-    sum = sqrt(sum);
 
-    const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
+    const float norm = sum > 0.0 ? 1.0 / sum : 0.0f;
 
     for (int i = 0; i < n; i++) {
         out[i] = inp[i] * norm;
@@ -3077,6 +2732,14 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n)
         sum2 += embd2[i] * embd2[i];
     }
 
+    // Handle the case where one or both vectors are zero vectors
+    if (sum1 == 0.0 || sum2 == 0.0) {
+        if (sum1 == 0.0 && sum2 == 0.0) {
+            return 1.0f; // two zero vectors are similar
+        }
+        return 0.0f;
+    }
+
     return sum / (sqrt(sum1) * sqrt(sum2));
 }
 
index 9a1dc4a2fe4c1b49993e02cd8e2c67d112525c48..a5c738f8b643f87787f9ea4ec2dafcd01cef7887 100644 (file)
@@ -152,7 +152,6 @@ struct gpt_params {
     bool prompt_cache_all  = false; // save user input and generations to prompt cache
     bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
 
-    bool embedding         = false; // get only sentence embedding
     bool escape            = true;  // escape "\n", "\r", "\t", "\'", "\"", and "\\"
     bool multiline_input   = false; // reverse the usage of `\`
     bool simple_io         = false; // improves compatibility with subprocesses and limited consoles
@@ -179,6 +178,12 @@ struct gpt_params {
     std::string mmproj = "";        // path to multimodal projector
     std::vector<std::string> image; // path to image file(s)
 
+    // embedding
+    bool embedding         = false; // get only sentence embedding
+    int32_t embd_normalize = 2;     // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
+    std::string embd_out   = "";    // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
+    std::string embd_sep   = "\n";  // separator of embendings
+
     // server params
     int32_t port           = 8080;         // server listens on this network port
     int32_t timeout_read   = 600;          // http read timeout in seconds
@@ -377,7 +382,7 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz
 // Embedding utils
 //
 
-void llama_embd_normalize(const float * inp, float * out, int n);
+void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
 
 float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);
 
index 2298ec3e7fcb34c63533572da379104cab20d33c..86df1895878465c11482effa9bed005ea5528f10 100644 (file)
@@ -19,3 +19,43 @@ llama-embedding.exe -m ./path/to/model --log-disable -p "Hello World!" 2>$null
 ```
 
 The above command will output space-separated float values.
+
+## extra parameters
+### --embd-normalize $integer$
+| $integer$ | description         | formula |
+|-----------|---------------------|---------|
+| $-1$      | none                |
+| $0$       | max absolute int16  | $\Large{{32760 * x_i} \over\max \lvert x_i\rvert}$
+| $1$       | taxicab             | $\Large{x_i \over\sum \lvert x_i\rvert}$
+| $2$       | euclidean (default) | $\Large{x_i \over\sqrt{\sum x_i^2}}$
+| $>2$      | p-norm              | $\Large{x_i \over\sqrt[p]{\sum \lvert x_i\rvert^p}}$
+
+### --embd-output-format $'string'$
+| $'string'$ | description                  |  |
+|------------|------------------------------|--|
+| ''         | same as before               | (default)
+| 'array'    | single embeddings            | $[[x_1,...,x_n]]$
+|            | multiple embeddings          | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$
+| 'json'     | openai style                 |
+| 'json+'    | add cosine similarity matrix |
+
+### --embd-separator $"string"$
+| $"string"$   | |
+|--------------|-|
+| "\n"         | (default)
+| "<#embSep#>" | for exemple
+| "<#sep#>"    | other exemple
+
+## examples
+### Unix-based systems (Linux, macOS, etc.):
+
+```bash
+./embedding -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --embd-separator '<#sep#>' --embd-normalize 2  --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
+```
+
+### Windows:
+
+```powershell
+embedding.exe -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --embd-separator '<#sep#>' --embd-normalize 2  --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
+```
+
index b4b73c0175cda0beea53a24375ee080f67da6d2a..1466e5b2bc512213008819f41c5005ac1ac123da 100644 (file)
@@ -7,13 +7,19 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-static std::vector<std::string> split_lines(const std::string & s) {
-    std::string line;
+static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
     std::vector<std::string> lines;
-    std::stringstream ss(s);
-    while (std::getline(ss, line)) {
-        lines.push_back(line);
+    size_t start = 0;
+    size_t end = s.find(separator);
+
+    while (end != std::string::npos) {
+        lines.push_back(s.substr(start, end - start));
+        start = end + separator.length();
+        end = s.find(separator, start);
     }
+
+    lines.push_back(s.substr(start)); // Add the last part
+
     return lines;
 }
 
@@ -24,7 +30,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
 
@@ -44,13 +50,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
         float * out = output + batch.seq_id[i][0] * n_embd;
-        //TODO: I would also add a parameter here to enable normalization or not.
-        /*fprintf(stdout, "unnormalized_embedding:");
-        for (int hh = 0; hh < n_embd; hh++) {
-            fprintf(stdout, "%9.6f ", embd[hh]);
-        }
-        fprintf(stdout, "\n");*/
-        llama_embd_normalize(embd, out, n_embd);
+        llama_embd_normalize(embd, out, n_embd, embd_norm);
     }
 }
 
@@ -110,7 +110,7 @@ int main(int argc, char ** argv) {
     }
 
     // split the prompt into lines
-    std::vector<std::string> prompts = split_lines(params.prompt);
+    std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);
 
     // max batch size
     const uint64_t n_batch = params.n_batch;
@@ -170,7 +170,7 @@ int main(int argc, char ** argv) {
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch) {
             float * out = emb + p * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd);
+            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
             llama_batch_clear(batch);
             p += s;
             s = 0;
@@ -183,29 +183,78 @@ int main(int argc, char ** argv) {
 
     // final batch
     float * out = emb + p * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd);
-
-    // print the first part of the embeddings or for a single prompt, the full embedding
-    fprintf(stdout, "\n");
-    for (int j = 0; j < n_prompts; j++) {
-        fprintf(stdout, "embedding %d: ", j);
-        for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
-            fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
-        }
-        fprintf(stdout, "\n");
-    }
+    batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
 
-    // print cosine similarity matrix
-    if (n_prompts > 1) {
+    if (params.embd_out.empty()) {
+        // print the first part of the embeddings or for a single prompt, the full embedding
         fprintf(stdout, "\n");
-        printf("cosine similarity matrix:\n\n");
-        for (int i = 0; i < n_prompts; i++) {
-            for (int j = 0; j < n_prompts; j++) {
-                float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
-                fprintf(stdout, "%6.2f ", sim);
+        for (int j = 0; j < n_prompts; j++) {
+            fprintf(stdout, "embedding %d: ", j);
+            for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
+                if (params.embd_normalize == 0) {
+                    fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
+                } else {
+                    fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+                }
+            }
+            fprintf(stdout, "\n");
+        }
+
+        // print cosine similarity matrix
+        if (n_prompts > 1) {
+            fprintf(stdout, "\n");
+            printf("cosine similarity matrix:\n\n");
+            for (int i = 0; i < n_prompts; i++) {
+                fprintf(stdout, "%6.6s ", prompts[i].c_str());
             }
             fprintf(stdout, "\n");
+            for (int i = 0; i < n_prompts; i++) {
+                for (int j = 0; j < n_prompts; j++) {
+                    float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                    fprintf(stdout, "%6.2f ", sim);
+                }
+                fprintf(stdout, "%1.10s", prompts[i].c_str());
+                fprintf(stdout, "\n");
+            }
+        }
+    }
+
+    if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
+        const bool notArray = params.embd_out != "array";
+
+        fprintf(stdout, notArray ? "{\n  \"object\": \"list\",\n  \"data\": [\n" : "[");
+        for (int j = 0;;) { // at least one iteration (one prompt)
+            if (notArray) fprintf(stdout, "    {\n      \"object\": \"embedding\",\n      \"index\": %d,\n      \"embedding\": ",j);
+            fprintf(stdout, "[");
+            for (int i = 0;;) { // at least one iteration (n_embd > 0)
+                fprintf(stdout, params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
+                i++;
+                if (i < n_embd) fprintf(stdout, ","); else break;
+            }
+            fprintf(stdout, notArray ? "]\n    }" : "]");
+            j++;
+            if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break;
         }
+        fprintf(stdout, notArray ? "\n  ]" : "]\n");
+
+        if (params.embd_out == "json+" && n_prompts > 1) {
+            fprintf(stdout, ",\n  \"cosineSimilarity\": [\n");
+            for (int i = 0;;) { // at least two iteration (n_prompts > 1)
+                fprintf(stdout, "    [");
+                for (int j = 0;;) { // at least two iteration (n_prompts > 1)
+                    float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                    fprintf(stdout, "%6.2f", sim);
+                    j++;
+                    if (j < n_prompts) fprintf(stdout, ", "); else break;
+                }
+                fprintf(stdout, " ]");
+                i++;
+                if (i < n_prompts) fprintf(stdout, ",\n"); else break;
+            }
+            fprintf(stdout, "\n  ]");
+        }
+
+        if (notArray) fprintf(stdout, "\n}\n");
     }
 
     // clean up