]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : ensure batches are either all embed or all completion (#8420)
authorDouglas Hanley <redacted>
Fri, 12 Jul 2024 08:14:12 +0000 (03:14 -0500)
committerGitHub <redacted>
Fri, 12 Jul 2024 08:14:12 +0000 (11:14 +0300)
* make sure batches are all embed or all non-embed

* non-embedding batch for sampled tokens; fix unused params warning

examples/server/server.cpp

index efd426289a54cb4a51a0b7d72f24dab60755023e..badeb9121324facf357d426b3ae274ff23b837fd 100644 (file)
@@ -2005,6 +2005,11 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
+        // track if this is an embedding or non-embedding batch
+        // if we've added sampled tokens above, we are in non-embedding mode
+        // -1: none, 0: non-embedding, 1: embedding
+        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
+
         // next, batch any pending prompts without exceeding n_batch
         if (params.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
@@ -2175,6 +2180,14 @@ struct server_context {
                         }
                     }
 
+                    // check that we are in the right batch_type, if not defer the slot
+                    bool slot_type = slot.embedding ? 1 : 0;
+                    if (batch_type == -1) {
+                        batch_type = slot_type;
+                    } else if (batch_type != slot_type) {
+                        continue;
+                    }
+
                     // keep only the common part
                     int p0 = (int) system_tokens.size() + slot.n_past;
                     if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
@@ -2276,6 +2289,9 @@ struct server_context {
             {"n_tokens", batch.n_tokens},
         });
 
+        // make sure we're in the right embedding mode
+        llama_set_embeddings(ctx, batch_type == 1);
+
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -2990,6 +3006,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json data = json::parse(req.body);
@@ -3085,6 +3106,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
@@ -3157,6 +3183,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json data = json::parse(req.body);
@@ -3243,13 +3274,8 @@ int main(int argc, char ** argv) {
         return res.set_content(data.dump(), "application/json; charset=utf-8");
     };
 
-    const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!params.embedding) {
-            res.status = 501;
-            res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
-            return;
-        }
 
         const json body = json::parse(req.body);
         bool is_openai = false;