]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
[Fix] Reenable server embedding endpoint (#1937)
authorHenri Vasserman <redacted>
Mon, 19 Jun 2023 22:12:39 +0000 (01:12 +0300)
committerGitHub <redacted>
Mon, 19 Jun 2023 22:12:39 +0000 (01:12 +0300)
* Add back embedding feature

* Update README

examples/server/README.md
examples/server/server.cpp

index 474a28b20018ff5308b4a95551d69fc203e0a5ff..fa95c00441bc220c1363c5e649fe42fb2c634145 100644 (file)
@@ -21,6 +21,7 @@ Command line options:
 -   `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
 -   `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
 -   `--port`: Set the port to listen. Default: `8080`.
+-   `--embedding`: Enable embedding extraction, Default: disabled.
 
 ## Build
 
@@ -119,14 +120,14 @@ node .
 
     `top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9).
 
-    `n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character.  (default: 128, -1 = infinity).
+    `n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. (default: 128, -1 = infinity).
 
     `n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context.
     By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt.
 
     `stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`.
 
-    `prompt`: Provide a prompt. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate.
+    `prompt`: Provide a prompt. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate. A space is inserted in the front like main.cpp does.
 
     `stop`: Specify a JSON array of stopping strings.
     These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []).
@@ -163,6 +164,14 @@ node .
 
     `content`: Set the text to tokenize.
 
+    Note that the special `BOS` token is not added in fron of the text and also a space character is not inserted automatically as it is for `/completion`.
+
+-   **POST** `/embedding`: Generate embedding of a given text just as [the embedding example](../embedding) does.
+
+    *Options:*
+
+    `content`: Set the text to process.
+
 ## More examples
 
 ### Interactive mode
index 12d4e2fa4b4f55dbe3c74db2c5050401f569429e..c0984aadb92ba733dcc47170bba255c7238e86c3 100644 (file)
@@ -254,6 +254,11 @@ struct llama_server_context {
             n_past += n_eval;
         }
 
+        if (params.n_predict == 0) {
+            has_next_token = false;
+            return llama_token_eos();
+        }
+
         // out of user input, sample next token
         const float temp = params.temp;
         const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
@@ -419,6 +424,19 @@ struct llama_server_context {
 
         return token_text;
     }
+
+    std::vector<float> getEmbedding() {
+        static const int n_embd = llama_n_embd(ctx);
+        if (!params.embedding) {
+            LOG_WARNING("embedding disabled", {
+                { "params.embedding", params.embedding },
+            });
+            return std::vector<float>(n_embd, 0.0f);
+        }
+        const float * data = llama_get_embeddings(ctx);
+        std::vector<float> embedding(data, data + n_embd);
+        return embedding;
+    }
 };
 
 static void server_print_usage(const char * argv0, const gpt_params & params,
@@ -457,6 +475,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params,
     fprintf(stderr, "  --host                ip address to listen (default  (default: %s)\n", sparams.hostname.c_str());
     fprintf(stderr, "  --port PORT           port to listen (default  (default: %d)\n", sparams.port);
     fprintf(stderr, "  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
+    fprintf(stderr, "  --embedding           enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     fprintf(stderr, "\n");
 }
 
@@ -603,6 +622,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
             params.use_mlock = true;
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
+        } else if (arg == "--embedding") {
+            params.embedding = true;
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             server_print_usage(argv[0], default_params, default_sparams);
@@ -646,6 +667,12 @@ static json format_generation_settings(llama_server_context & llama) {
     };
 }
 
+static json format_embedding_response(llama_server_context & llama) {
+    return json {
+        { "embedding", llama.getEmbedding() },
+    };
+}
+
 static json format_final_response(llama_server_context & llama, const std::string & content) {
     return json {
         { "content", content },
@@ -881,12 +908,27 @@ int main(int argc, char ** argv) {
 
     svr.Post("/tokenize", [&llama](const Request & req, Response & res) {
         const json body = json::parse(req.body);
-        const std::string content = body["content"].get<std::string>();
+        const std::string content = body.value("content", "");
         const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
         const json data = format_tokenizer_response(tokens);
         return res.set_content(data.dump(), "application/json");
     });
 
+    svr.Post("/embedding", [&llama](const Request & req, Response & res) {
+        const json body = json::parse(req.body);
+
+        llama.rewind();
+        llama_reset_timings(llama.ctx);
+        llama.params.prompt = body.value("content", "");
+        llama.params.n_predict = 0;
+        llama.loadPrompt();
+        llama.beginCompletion();
+        llama.doCompletion();
+
+        const json data = format_embedding_response(llama);
+        return res.set_content(data.dump(), "application/json");
+    });
+
     svr.set_logger(log_server_request);
 
     svr.set_exception_handler([](const Request &, Response & res, std::exception_ptr ep) {