]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
py : fix oai proxy (#3972)
authorrhjdvsgsgks <redacted>
Thu, 30 Nov 2023 20:50:40 +0000 (20:50 +0000)
committerGitHub <redacted>
Thu, 30 Nov 2023 20:50:40 +0000 (22:50 +0200)
* fix oai proxy

fix generation not stoped while bot stop talking in chat mode

fix possible `slot_id` not exist

response for cors (and pre flight)

* oai proxy: workaround for some client (such as Chatbox)

* use stop as separator to replace hardcoded `\n`

examples/server/api_like_OAI.py

index 313e1a9652d14a87ce2b0b4be9522092c0b46531..830c056d4acfc6d2e864a184727a1f8dbd831b8e 100755 (executable)
@@ -11,10 +11,10 @@ app = Flask(__name__)
 slot_id = -1
 
 parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
-parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
-parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
-parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
-parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
+parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')
+parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: 'USER: ')", default="USER: ")
+parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: 'ASSISTANT: ')", default="ASSISTANT: ")
+parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: 'ASSISTANT's RULE: ')", default="ASSISTANT's RULE: ")
 parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
 parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
 parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
@@ -34,19 +34,19 @@ def is_present(json, key):
 
 #convert chat to prompt
 def convert_chat(messages):
-    prompt = "" + args.chat_prompt.replace("\\n", "\n")
 
-    system_n = args.system_name.replace("\\n", "\n")
-    user_n = args.user_name.replace("\\n", "\n")
-    ai_n = args.ai_name.replace("\\n", "\n")
-    stop = args.stop.replace("\\n", "\n")
+    system_n = args.system_name
+    user_n = args.user_name
+    ai_n = args.ai_name
+    stop = args.stop
 
+    prompt = "" + args.chat_prompt + stop
 
     for line in messages:
         if (line["role"] == "system"):
-            prompt += f"{system_n}{line['content']}"
+            prompt += f"{system_n}{line['content']}{stop}"
         if (line["role"] == "user"):
-            prompt += f"{user_n}{line['content']}"
+            prompt += f"{user_n}{line['content']}{stop}"
         if (line["role"] == "assistant"):
             prompt += f"{ai_n}{line['content']}{stop}"
     prompt += ai_n.rstrip()
@@ -130,7 +130,7 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
             }
         ]
     }
-    slot_id = data["slot_id"]
+    slot_id = data.get("slot_id")
     if (chat):
         if (start):
             resData["choices"][0]["delta"] =  {
@@ -150,11 +150,13 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
     return resData
 
 
-@app.route('/chat/completions', methods=['POST'])
-@app.route('/v1/chat/completions', methods=['POST'])
+@app.route('/chat/completions', methods=['POST', 'OPTIONS'])
+@app.route('/v1/chat/completions', methods=['POST', 'OPTIONS'])
 def chat_completions():
     if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
         return Response(status=403)
+    if request.method == 'OPTIONS':
+        return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
     body = request.get_json()
     stream = False
     tokenize = False
@@ -177,20 +179,22 @@ def chat_completions():
             data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
             time_now = int(time.time())
             resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
-            yield 'data: {}\n'.format(json.dumps(resData))
+            yield 'data: {}\n\n'.format(json.dumps(resData))
             for line in data.iter_lines():
                 if line:
                     decoded_line = line.decode('utf-8')
                     resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
-                    yield 'data: {}\n'.format(json.dumps(resData))
-        return Response(generate(), mimetype='text/event-stream')
+                    yield 'data: {}\n\n'.format(json.dumps(resData))
+        return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
 
 
-@app.route('/completions', methods=['POST'])
-@app.route('/v1/completions', methods=['POST'])
+@app.route('/completions', methods=['POST', 'OPTIONS'])
+@app.route('/v1/completions', methods=['POST', 'OPTIONS'])
 def completion():
     if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
         return Response(status=403)
+    if request.method == 'OPTIONS':
+        return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
     body = request.get_json()
     stream = False
     tokenize = False
@@ -216,8 +220,8 @@ def completion():
                 if line:
                     decoded_line = line.decode('utf-8')
                     resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
-                    yield 'data: {}\n'.format(json.dumps(resData))
-        return Response(generate(), mimetype='text/event-stream')
+                    yield 'data: {}\n\n'.format(json.dumps(resData))
+        return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
 
 if __name__ == '__main__':
     app.run(args.host, port=args.port)