]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add support for "encoding_format": "base64" to the */embeddings endpoints...
authorReza Kakhki <redacted>
Tue, 24 Dec 2024 20:33:04 +0000 (21:33 +0100)
committerGitHub <redacted>
Tue, 24 Dec 2024 20:33:04 +0000 (21:33 +0100)
* add support for base64

* fix base64 test

* improve test

---------

Co-authored-by: Xuan Son Nguyen <redacted>
examples/server/CMakeLists.txt
examples/server/server.cpp
examples/server/tests/unit/test_embedding.py
examples/server/utils.hpp

index a27597cbc294a0b3972e3667b1069fff03a7f41e..1b7cc8c1328e463b1c332085439fa250c9f384e8 100644 (file)
@@ -34,6 +34,7 @@ endforeach()
 add_executable(${TARGET} ${TARGET_SRCS})
 install(TARGETS ${TARGET} RUNTIME)
 
+target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
 target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
 
 if (LLAMA_SERVER_SSL)
index 3fbfb13c49b729f7a673d4c2d6ae4e1aef3133f1..30ff3b14957dc76b15831a5f483a935b9e0e890f 100644 (file)
@@ -3790,6 +3790,17 @@ int main(int argc, char ** argv) {
             return;
         }
 
+        bool use_base64 = false;
+        if (body.count("encoding_format") != 0) {
+            const std::string& format = body.at("encoding_format");
+            if (format == "base64") {
+                use_base64 = true;
+            } else if (format != "float") {
+                res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
         for (const auto & tokens : tokenized_prompts) {
             // this check is necessary for models that do not add BOS token to the input
@@ -3841,7 +3852,7 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
         res_ok(res, root);
     };
 
index 43e372fc70d71a3c23d8bc99eb9b4a3ea192fc73..8b0eb42b0926ff596327ecdbdf921b914a0e3026 100644 (file)
@@ -1,3 +1,5 @@
+import base64
+import struct
 import pytest
 from openai import OpenAI
 from utils import *
@@ -194,3 +196,42 @@ def test_embedding_usage_multiple():
     assert res.status_code == 200
     assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
     assert res.body['usage']['prompt_tokens'] == 2 * 9
+
+
+def test_embedding_openai_library_base64():
+    server.start()
+    test_input = "Test base64 embedding output"
+
+    # get embedding in default format
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": test_input
+    })
+    assert res.status_code == 200
+    vec0 = res.body["data"][0]["embedding"]
+
+    # get embedding in base64 format
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": test_input,
+        "encoding_format": "base64"
+    })
+
+    assert res.status_code == 200
+    assert "data" in res.body
+    assert len(res.body["data"]) == 1
+
+    embedding_data = res.body["data"][0]
+    assert "embedding" in embedding_data
+    assert isinstance(embedding_data["embedding"], str)
+
+    # Verify embedding is valid base64
+    decoded = base64.b64decode(embedding_data["embedding"])
+    # Verify decoded data can be converted back to float array
+    float_count = len(decoded) // 4  # 4 bytes per float
+    floats = struct.unpack(f'{float_count}f', decoded)
+    assert len(floats) > 0
+    assert all(isinstance(x, float) for x in floats)
+    assert len(floats) == len(vec0)
+
+    # make sure the decoded data is the same as the original
+    for x, y in zip(floats, vec0):
+        assert abs(x - y) < EPSILON
index 043d8b52897db01b6d5c596214dcbc1e0a28bf3e..334f2f19207ef4c7869c03ae261689bc8a6956f7 100644 (file)
@@ -3,6 +3,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "common/base64.hpp"
 
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -613,16 +614,31 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
+static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
     json data = json::array();
     int32_t n_tokens = 0;
     int i = 0;
     for (const auto & elem : embeddings) {
-        data.push_back(json{
-            {"embedding", json_value(elem, "embedding", json::array())},
-            {"index",     i++},
-            {"object",    "embedding"}
-        });
+        json embedding_obj;
+
+        if (use_base64) {
+            const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
+            const char* data_ptr = reinterpret_cast<const char*>(vec.data());
+            size_t data_size = vec.size() * sizeof(float);
+            embedding_obj = {
+                {"embedding", base64::encode(data_ptr, data_size)},
+                {"index", i++},
+                {"object", "embedding"},
+                {"encoding_format", "base64"}
+            };
+        } else {
+            embedding_obj = {
+                {"embedding", json_value(elem, "embedding", json::array())},
+                {"index", i++},
+                {"object", "embedding"}
+            };
+        }
+        data.push_back(embedding_obj);
 
         n_tokens += json_value(elem, "tokens_evaluated", 0);
     }