]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server tests : more pythonic process management; fix bare `except:` (#6146)
authorJared Van Bortel <redacted>
Wed, 20 Mar 2024 05:33:49 +0000 (01:33 -0400)
committerGitHub <redacted>
Wed, 20 Mar 2024 05:33:49 +0000 (06:33 +0100)
* server tests : remove seemingly redundant newlines in print()

* server tests : use built-in subprocess features, not os.kill and psutil

* server tests : do not catch e.g. SystemExit; use print_exc

* server tests: handle TimeoutExpired exception

* server tests: fix connect on dual-stack systems

* server: tests: add new tokens regex on windows generated following new repeat penalties default changed in (#6127)

* server: tests: remove the hack on windows since now we get the good socket family

* server: tests: add new tokens regex following new repeat penalties default changed in (#6127)

* server: tests: add new tokens regex following new repeat penalties default changed in (#6127)

---------

Co-authored-by: Pierrick HYMBERT <redacted>
examples/server/tests/features/environment.py
examples/server/tests/features/server.feature
examples/server/tests/features/steps/steps.py
examples/server/tests/requirements.txt

index 82104e9202e5e05e227e6f111025b0bf49077979..e7845dc2f51fc8cf24afceb7616cec83f498c65e 100644 (file)
@@ -5,15 +5,14 @@ import sys
 import time
 import traceback
 from contextlib import closing
-
-import psutil
+from subprocess import TimeoutExpired
 
 
 def before_scenario(context, scenario):
     context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
     if context.debug:
-        print("DEBUG=ON\n")
-    print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m\n")
+        print("DEBUG=ON")
+    print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
     port = 8080
     if 'PORT' in os.environ:
         port = int(os.environ['PORT'])
@@ -27,60 +26,40 @@ def after_scenario(context, scenario):
             return
         if scenario.status == "failed":
             if 'GITHUB_ACTIONS' in os.environ:
-                print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n\n")
+                print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n")
                 if os.path.isfile('llama.log'):
                     with closing(open('llama.log', 'r')) as f:
                         for line in f:
                             print(line)
             if not is_server_listening(context.server_fqdn, context.server_port):
-                print("\x1b[33;101mERROR: Server stopped listening\x1b[0m\n")
+                print("\x1b[33;101mERROR: Server stopped listening\x1b[0m")
 
-        if not pid_exists(context.server_process.pid):
+        if context.server_process.poll() is not None:
             assert False, f"Server not running pid={context.server_process.pid} ..."
 
-        server_graceful_shutdown(context)
+        server_graceful_shutdown(context)  # SIGINT
 
-        # Wait few for socket to free up
-        time.sleep(0.05)
+        try:
+            context.server_process.wait(0.5)
+        except TimeoutExpired:
+            print(f"server still alive after 500ms, force-killing pid={context.server_process.pid} ...")
+            context.server_process.kill()  # SIGKILL
+            context.server_process.wait()
 
-        attempts = 0
-        while pid_exists(context.server_process.pid) or is_server_listening(context.server_fqdn, context.server_port):
-            server_kill(context)
+        while is_server_listening(context.server_fqdn, context.server_port):
             time.sleep(0.1)
-            attempts += 1
-            if attempts > 5:
-                server_kill_hard(context)
-    except:
-        exc = sys.exception()
-        print("error in after scenario: \n")
-        print(exc)
-        print("*** print_tb: \n")
-        traceback.print_tb(exc.__traceback__, file=sys.stdout)
+    except Exception:
+        print("ignoring error in after_scenario:")
+        traceback.print_exc(file=sys.stdout)
 
 
 def server_graceful_shutdown(context):
-    print(f"shutting down server pid={context.server_process.pid} ...\n")
+    print(f"shutting down server pid={context.server_process.pid} ...")
     if os.name == 'nt':
-        os.kill(context.server_process.pid, signal.CTRL_C_EVENT)
+        interrupt = signal.CTRL_C_EVENT
     else:
-        os.kill(context.server_process.pid, signal.SIGINT)
-
-
-def server_kill(context):
-    print(f"killing server pid={context.server_process.pid} ...\n")
-    context.server_process.kill()
-
-
-def server_kill_hard(context):
-    pid = context.server_process.pid
-    path = context.server_path
-
-    print(f"Server dangling exits, hard killing force {pid}={path}...\n")
-    try:
-        psutil.Process(pid).kill()
-    except psutil.NoSuchProcess:
-        return False
-    return True
+        interrupt = signal.SIGINT
+    context.server_process.send_signal(interrupt)
 
 
 def is_server_listening(server_fqdn, server_port):
@@ -88,14 +67,5 @@ def is_server_listening(server_fqdn, server_port):
         result = sock.connect_ex((server_fqdn, server_port))
         _is_server_listening = result == 0
         if _is_server_listening:
-            print(f"server is listening on {server_fqdn}:{server_port}...\n")
+            print(f"server is listening on {server_fqdn}:{server_port}...")
         return _is_server_listening
-
-
-def pid_exists(pid):
-    try:
-        psutil.Process(pid)
-    except psutil.NoSuchProcess:
-        return False
-    return True
-
index 7448986e75a496110147e6fce045210d4cdd15b1..45a988db67f5fdfa914d7f1b8695e7463c2c827b 100644 (file)
@@ -35,9 +35,9 @@ Feature: llama.cpp server
     And   metric llamacpp:tokens_predicted is <n_predicted>
 
     Examples: Prompts
-      | prompt                                                                    | n_predict | re_content                    | n_prompt | n_predicted | truncated |
-      | I believe the meaning of life is                                          | 8         | (read\|going)+                | 18       | 8           | not       |
-      | Write a joke about AI from a very long prompt which will not be truncated | 256       | (princesses\|everyone\|kids)+ | 46       | 64          | not       |
+      | prompt                                                                    | n_predict | re_content                                  | n_prompt | n_predicted | truncated |
+      | I believe the meaning of life is                                          | 8         | (read\|going)+                              | 18       | 8           | not       |
+      | Write a joke about AI from a very long prompt which will not be truncated | 256       | (princesses\|everyone\|kids\|Anna\|forest)+ | 46       | 64          | not       |
 
   Scenario: Completion prompt truncated
     Given a prompt:
@@ -48,7 +48,7 @@ Feature: llama.cpp server
     Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
     """
     And   a completion request with no api error
-    Then  64 tokens are predicted matching fun|Annaks|popcorns|pictry
+    Then  64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
     And   the completion is  truncated
     And   109 prompt tokens are processed
 
@@ -65,9 +65,9 @@ Feature: llama.cpp server
     And   the completion is <truncated> truncated
 
     Examples: Prompts
-      | model        | system_prompt               | user_prompt                          | max_tokens | re_content             | n_prompt | n_predicted | enable_streaming | truncated |
-      | llama-2      | Book                        | What is the best book                | 8          | (Here\|what)+          | 77       | 8           | disabled         | not       |
-      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128        | (thanks\|happy\|bird)+ | -1       | 64          | enabled          |           |
+      | model        | system_prompt               | user_prompt                          | max_tokens | re_content                        | n_prompt | n_predicted | enable_streaming | truncated |
+      | llama-2      | Book                        | What is the best book                | 8          | (Here\|what)+                     | 77       | 8           | disabled         | not       |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128        | (thanks\|happy\|bird\|Annabyear)+ | -1       | 64          | enabled          |           |
 
 
   Scenario: Tokenize / Detokenize
index 9e348d5fc4c37f294ea1d50d0ca776dd688a78cb..40c97001ac5d766b944613ee7c586f1b217a3516 100644 (file)
@@ -66,7 +66,7 @@ def step_server_config(context, server_fqdn, server_port):
 def step_download_hf_model(context, hf_file, hf_repo):
     context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
     if context.debug:
-        print(f"model file: {context.model_file}\n")
+        print(f"model file: {context.model_file}")
 
 
 @step('a model file {model_file}')
@@ -137,9 +137,12 @@ def step_start_server(context):
     if 'GITHUB_ACTIONS' in os.environ:
         max_attempts *= 2
 
+    addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM)
+    family, typ, proto, _, sockaddr = addrs[0]
+
     while True:
-        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
-            result = sock.connect_ex((context.server_fqdn, context.server_port))
+        with closing(socket.socket(family, typ, proto)) as sock:
+            result = sock.connect_ex(sockaddr)
             if result == 0:
                 print("\x1b[33;46mserver started!\x1b[0m")
                 return
@@ -209,7 +212,7 @@ async def step_request_completion(context, api_error):
                                           user_api_key=context.user_api_key)
     context.tasks_result.append(completion)
     if context.debug:
-        print(f"Completion response: {completion}\n")
+        print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
 
@@ -354,7 +357,7 @@ def step_prompt_passkey(context, passkey, i_pos):
         prompt += context.prompt_junk_suffix
     if context.debug:
         passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
-        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
+        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```")
     context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
     context.n_prompts = len(context.prompts)
 
@@ -363,7 +366,7 @@ def step_prompt_passkey(context, passkey, i_pos):
 @async_run_until_complete
 async def step_oai_chat_completions(context, api_error):
     if context.debug:
-        print(f"Submitting OAI compatible completions request...\n")
+        print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     completion = await oai_chat_completions(context.prompts.pop(),
                                             context.system_prompt,
@@ -508,12 +511,12 @@ async def step_all_embeddings_are_the_same(context):
             embedding1 = np.array(embeddings[i])
             embedding2 = np.array(embeddings[j])
             if context.debug:
-                print(f"embedding1: {embedding1[-8:]}\n")
-                print(f"embedding2: {embedding2[-8:]}\n")
+                print(f"embedding1: {embedding1[-8:]}")
+                print(f"embedding2: {embedding2[-8:]}")
             similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
             msg = f"Similarity between {i} and {j}: {similarity:.10f}"
             if context.debug:
-                print(f"{msg}\n")
+                print(f"{msg}")
             assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
 
 
@@ -630,7 +633,7 @@ async def step_prometheus_metrics_exported(context):
             metrics_raw = await metrics_response.text()
             metric_exported = False
             if context.debug:
-                print(f"/metrics answer:\n{metrics_raw}\n")
+                print(f"/metrics answer:\n{metrics_raw}")
             context.metrics = {}
             for metric in parser.text_string_to_metric_families(metrics_raw):
                 match metric.name:
@@ -932,7 +935,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
             last_match = end
         highlighted += content[last_match:]
         if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
-          print(f"Checking completion response: {highlighted}\n")
+          print(f"Checking completion response: {highlighted}")
         assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
     if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
@@ -942,7 +945,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
     if context.debug:
-        print(f"Waiting for all {n_tasks} tasks results...\n")
+        print(f"Waiting for all {n_tasks} tasks results...")
     for task_no in range(n_tasks):
         context.tasks_result.append(await context.concurrent_tasks.pop())
     n_completions = len(context.tasks_result)
@@ -959,7 +962,7 @@ async def wait_for_health_status(context,
                                  slots_processing=None,
                                  expected_slots=None):
     if context.debug:
-        print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
+        print(f"Starting checking for health for expected_health_status={expected_health_status}")
     interval = 0.5
     counter = 0
     if 'GITHUB_ACTIONS' in os.environ:
@@ -1048,8 +1051,6 @@ def start_server_background(context):
     if 'LLAMA_SERVER_BIN_PATH' in os.environ:
         context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
     server_listen_addr = context.server_fqdn
-    if os.name == 'nt':
-        server_listen_addr = '0.0.0.0'
     server_args = [
         '--host', server_listen_addr,
         '--port', context.server_port,
@@ -1088,7 +1089,7 @@ def start_server_background(context):
         server_args.append('--verbose')
     if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
         server_args.extend(['--log-format', "text"])
-    print(f"starting server with: {context.server_path} {server_args}\n")
+    print(f"starting server with: {context.server_path} {server_args}")
     flags = 0
     if 'nt' == os.name:
         flags |= subprocess.DETACHED_PROCESS
index c2c960102b52346fe704a382ae2c8428b3966c7f..2e4f42ad28c233fd2e3cec00f2a33aecb3925f5a 100644 (file)
@@ -3,5 +3,4 @@ behave~=1.2.6
 huggingface_hub~=0.20.3
 numpy~=1.24.4
 openai~=0.25.0
-psutil~=5.9.8
 prometheus-client~=0.20.0