Add an API example using server.cpp similar to OAI. (#2009)
author    jwj7140 <redacted>
          Tue, 4 Jul 2023 18:06:12 +0000 (03:06 +0900)
committer GitHub <redacted>
          Tue, 4 Jul 2023 18:06:12 +0000 (21:06 +0300)
* add api_like_OAI.py
* add evaluated token count to server
* add /v1/ endpoints binding

examples/server/README.md
examples/server/api_like_OAI.py [new file with mode: 0755]
examples/server/server.cpp

diff --git a/examples/server/README.md b/examples/server/README.md
index ba4b2fec9d1df08f6d2e1b43f244f83ed2224f08..4ed226e048218085a4d907e9d18199845263bc33 100644 (file)
@@ -190,3 +190,19 @@ Run with bash:
 ```sh
 bash chat.sh
 ```
+
+### API like OAI
+
+API example using Python Flask: [api_like_OAI.py](api_like_OAI.py)
+This example must be used together with server.cpp.
+
+```sh
+python api_like_OAI.py
+```
+
+After running the API server, you can use it in Python by setting the API base URL.
+```python
+openai.api_base = "http://<Your api-server IP>:port"
+```
+
+Then you can use llama.cpp as an OpenAI-compatible **chat.completion** or **text_completion** API.
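+
+For example, here is a minimal request sketch using the legacy `openai` Python package (0.x-style `ChatCompletion` API). The key and model name below are placeholders: the proxy ignores the `model` field and only checks the key when `--api-key` is set.
+
+```python
+import openai
+
+# point the client at api_like_OAI.py instead of api.openai.com
+openai.api_key = "sk-no-key-required"         # placeholder; only checked if --api-key is set
+openai.api_base = "http://127.0.0.1:8081/v1"  # default host/port of api_like_OAI.py
+
+response = openai.ChatCompletion.create(
+    model="LLaMA_CPP",  # ignored by the proxy
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello!"},
+    ],
+)
+print(response["choices"][0]["message"]["content"])
+```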
diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py
new file mode 100755 (executable)
index 0000000..aa325a0
--- /dev/null
@@ -0,0 +1,219 @@
+import argparse
+from flask import Flask, jsonify, request, Response
+import urllib.parse
+import requests
+import time
+import json
+
+
+app = Flask(__name__)
+
+parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
+parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
+parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
+parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
+parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
+parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
+parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
+parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
+parser.add_argument("--host", type=str, help="Set the ip address to listen.(default: 127.0.0.1)", default='127.0.0.1')
+parser.add_argument("--port", type=int, help="Set the port to listen.(default: 8081)", default=8081)
+
+args = parser.parse_args()
+
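+# helper: return True if `key` is present in the parsed request body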
+def is_present(json, key):
+    try:
+        buf = json[key]
+    except KeyError:
+        return False
+    return True
+
+
+
+# convert an OAI-style chat message list into a single prompt string
+def convert_chat(messages):
+    prompt = "" + args.chat_prompt.replace("\\n", "\n")
+
+    system_n = args.system_name.replace("\\n", "\n")
+    user_n = args.user_name.replace("\\n", "\n")
+    ai_n = args.ai_name.replace("\\n", "\n")
+    stop = args.stop.replace("\\n", "\n")
+
+
+    for line in messages:
+        if (line["role"] == "system"):
+            prompt += f"{system_n}{line['content']}"
+        if (line["role"] == "user"):
+            prompt += f"{user_n}{line['content']}"
+        if (line["role"] == "assistant"):
+            prompt += f"{ai_n}{line['content']}{stop}"
+    prompt += ai_n.rstrip()
+
+    return prompt
+
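+# map an OAI-style request body onto the server.cpp /completion parameters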
+def make_postData(body, chat=False, stream=False):
+    postData = {}
+    if (chat):
+        postData["prompt"] = convert_chat(body["messages"])
+    else:
+        postData["prompt"] = body["prompt"]
+    if(is_present(body, "temperature")): postData["temperature"] = body["temperature"]
+    if(is_present(body, "top_k")): postData["top_k"] = body["top_k"]
+    if(is_present(body, "top_p")): postData["top_p"] = body["top_p"]
+    if(is_present(body, "max_tokens")): postData["n_predict"] = body["max_tokens"]
+    if(is_present(body, "presence_penalty")): postData["presence_penalty"] = body["presence_penalty"]
+    if(is_present(body, "frequency_penalty")): postData["frequency_penalty"] = body["frequency_penalty"]
+    if(is_present(body, "repeat_penalty")): postData["repeat_penalty"] = body["repeat_penalty"]
+    if(is_present(body, "mirostat")): postData["mirostat"] = body["mirostat"]
+    if(is_present(body, "mirostat_tau")): postData["mirostat_tau"] = body["mirostat_tau"]
+    if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"]
+    if(is_present(body, "seed")): postData["seed"] = body["seed"]
+    if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()]
+    if (args.stop != ""):
+        postData["stop"] = [args.stop]
+    else:
+        postData["stop"] = []
+    if(is_present(body, "stop")): postData["stop"] += body["stop"]
+    postData["n_keep"] = -1
+    postData["stream"] = stream
+
+    return postData
+
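+# build a full (non-streaming) OAI-style response from the server.cpp result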
+def make_resData(data, chat=False, promptToken=[]):
+    resData = {
+        "id": "chatcmpl" if (chat) else "cmpl",
+        "object": "chat.completion" if (chat) else "text_completion",
+        "created": int(time.time()),
+        "truncated": data["truncated"],
+        "model": "LLaMA_CPP",
+        "usage": {
+            "prompt_tokens": data["tokens_evaluated"],
+            "completion_tokens": data["tokens_predicted"],
+            "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"]
+        }
+    }
+    if (len(promptToken) != 0):
+        resData["promptToken"] = promptToken
+    if (chat):
+        #only one choice is supported
+        resData["choices"] = [{
+            "index": 0,
+            "message": {
+                "role": "assistant",
+                "content": data["content"],
+            },
+            "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
+        }]
+    else:
+        #only one choice is supported
+        resData["choices"] = [{
+            "text": data["content"],
+            "index": 0,
+            "logprobs": None,
+            "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
+        }]
+    return resData
+
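+# build one OAI-style streaming chunk from a single server.cpp stream event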
+def make_resData_stream(data, chat=False, time_now = 0, start=False):
+    resData = {
+        "id": "chatcmpl" if (chat) else "cmpl",
+        "object": "chat.completion.chunk" if (chat) else "text_completion.chunk",
+        "created": time_now,
+        "model": "LLaMA_CPP",
+        "choices": [
+            {
+                "finish_reason": None,
+                "index": 0
+            }
+        ]
+    }
+    if (chat):
+        if (start):
+            resData["choices"][0]["delta"] =  {
+                "role": "assistant"
+            }
+        else:
+            resData["choices"][0]["delta"] =  {
+                "content": data["content"]
+            }
+            if (data["stop"]):
+                resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
+    else:
+        resData["choices"][0]["text"] = data["content"]
+        if (data["stop"]):
+            resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
+
+    return resData
+
+
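+# chat endpoint: converts the messages to a prompt, forwards it to server.cpp and re-shapes the reply (optionally as an SSE stream)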
+@app.route('/chat/completions', methods=['POST'])
+@app.route('/v1/chat/completions', methods=['POST'])
+def chat_completions():
+    if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
+        return Response(status=403)
+    body = request.get_json()
+    stream = False
+    tokenize = False
+    if(is_present(body, "stream")): stream = body["stream"]
+    if(is_present(body, "tokenize")): tokenize = body["tokenize"]
+    postData = make_postData(body, chat=True, stream=stream)
+
+    promptToken = []
+    if (tokenize):
+        tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
+        promptToken = tokenData["tokens"]
+
+    if (not stream):
+        data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
+        print(data.json())
+        resData = make_resData(data.json(), chat=True, promptToken=promptToken)
+        return jsonify(resData)
+    else:
+        def generate():
+            data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
+            time_now = int(time.time())
+            resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
+            yield 'data: {}\n'.format(json.dumps(resData))
+            for line in data.iter_lines():
+                if line:
+                    decoded_line = line.decode('utf-8')
+                    resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
+                    yield 'data: {}\n'.format(json.dumps(resData))
+        return Response(generate(), mimetype='text/event-stream')
+
+
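+# completion endpoint: forwards the raw prompt to server.cpp and re-shapes the reply (optionally as an SSE stream)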
+@app.route('/completions', methods=['POST'])
+@app.route('/v1/completions', methods=['POST'])
+def completion():
+    if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
+        return Response(status=403)
+    body = request.get_json()
+    stream = False
+    tokenize = False
+    if(is_present(body, "stream")): stream = body["stream"]
+    if(is_present(body, "tokenize")): tokenize = body["tokenize"]
+    postData = make_postData(body, chat=False, stream=stream)
+
+    promptToken = []
+    if (tokenize):
+        tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
+        promptToken = tokenData["tokens"]
+
+    if (not stream):
+        data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
+        print(data.json())
+        resData = make_resData(data.json(), chat=False, promptToken=promptToken)
+        return jsonify(resData)
+    else:
+        def generate():
+            data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
+            time_now = int(time.time())
+            for line in data.iter_lines():
+                if line:
+                    decoded_line = line.decode('utf-8')
+                    resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
+                    yield 'data: {}\n'.format(json.dumps(resData))
+        return Response(generate(), mimetype='text/event-stream')
+
+if __name__ == '__main__':
+    app.run(args.host, port=args.port)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 043e49750707b58ee11d906e12be6727683efda5..a835c398885f6b22be2acbae27f9147625dbd8af 100644 (file)
@@ -158,6 +158,7 @@ struct llama_server_context {
     std::string generated_text;
     std::vector<completion_token_output> generated_token_probs;
 
+    size_t num_prompt_tokens = 0;
     size_t num_tokens_predicted = 0;
     size_t n_past = 0;
     size_t n_remain = 0;
@@ -195,6 +196,7 @@ struct llama_server_context {
 
     void rewind() {
         params.antiprompt.clear();
+        num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
         generated_text.reserve(params.n_ctx);
@@ -226,17 +228,18 @@ struct llama_server_context {
     void loadPrompt() {
         params.prompt.insert(0, 1, ' '); // always add a first space
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
+        num_prompt_tokens = prompt_tokens.size();
 
         if (params.n_keep < 0) {
-            params.n_keep = (int)prompt_tokens.size();
+            params.n_keep = (int)num_prompt_tokens;
         }
         params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
 
         // if input prompt is too big, truncate like normal
-        if (prompt_tokens.size() >= (size_t)params.n_ctx) {
+        if (num_prompt_tokens >= (size_t)params.n_ctx) {
             const int n_left = (params.n_ctx - params.n_keep) / 2;
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_left - 1) / n_left;
+            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
             new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
             std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
 
@@ -250,7 +253,7 @@ struct llama_server_context {
             truncated = true;
             prompt_tokens = new_tokens;
         } else {
-            const size_t ps = prompt_tokens.size();
+            const size_t ps = num_prompt_tokens;
             std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
             std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
         }
@@ -258,7 +261,7 @@ struct llama_server_context {
         // compare the evaluated prompt with the new prompt
         n_past = common_part(embd, prompt_tokens);
         embd = prompt_tokens;
-        if (n_past == prompt_tokens.size()) {
+        if (n_past == num_prompt_tokens) {
             // we have to evaluate at least 1 token to generate logits.
             n_past--;
         }
@@ -763,6 +766,7 @@ static json format_final_response(llama_server_context & llama, const std::strin
         { "stop", true },
         { "model", llama.params.model_alias },
         { "tokens_predicted", llama.num_tokens_predicted },
+        { "tokens_evaluated", llama.num_prompt_tokens },
         { "generation_settings", format_generation_settings(llama) },
         { "prompt", llama.params.prompt },
         { "truncated", llama.truncated },