import asyncio
+import collections
import json
import os
import re
@step(u'concurrent completion requests')
@async_run_until_complete()
async def step_concurrent_completion_requests(context):
- await concurrent_completion_requests(context,
- request_completion,
- # prompt is inserted automatically
- context.base_url,
- debug=context.debug,
- n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
- server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
- user_api_key=context.user_api_key if hasattr(context,
- 'user_api_key') else None)
+ await concurrent_requests(context,
+ request_completion,
+ # prompt is inserted automatically
+ context.base_url,
+ debug=context.debug,
+ n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
+ server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
+ user_api_key=context.user_api_key if hasattr(context,
+ 'user_api_key') else None)
@step(u'concurrent OAI completions requests')
@async_run_until_complete
async def step_oai_chat_completions(context):
- await concurrent_completion_requests(context, oai_chat_completions,
- # user_prompt is inserted automatically
- context.system_prompt,
- context.base_url,
- True, # async_client
- model=context.model
- if hasattr(context, 'model') else None,
- n_predict=context.n_predict
- if hasattr(context, 'n_predict') else None,
- enable_streaming=context.enable_streaming
- if hasattr(context, 'enable_streaming') else None,
- server_seed=context.server_seed
- if hasattr(context, 'server_seed') else None,
- user_api_key=context.user_api_key
- if hasattr(context, 'user_api_key') else None)
+ await concurrent_requests(context, oai_chat_completions,
+ # user_prompt is inserted automatically
+ context.system_prompt,
+ context.base_url,
+ True, # async_client
+ model=context.model
+ if hasattr(context, 'model') else None,
+ n_predict=context.n_predict
+ if hasattr(context, 'n_predict') else None,
+ enable_streaming=context.enable_streaming
+ if hasattr(context, 'enable_streaming') else None,
+ server_seed=context.server_seed
+ if hasattr(context, 'server_seed') else None,
+ user_api_key=context.user_api_key
+ if hasattr(context, 'user_api_key') else None)
@step(u'all prompts are predicted')
@step(u'embeddings are computed for')
@async_run_until_complete
async def step_compute_embedding(context):
- content = context.text
- base_url = context.base_url
- context.embeddings = await request_embedding(content, base_url)
+ context.embeddings = await request_embedding(context.text, base_url=context.base_url)
@step(u'embeddings are generated')
def step_assert_embeddings(context):
- assert_embeddings(context.embeddings)
+ if len(context.prompts) == 0:
+ assert_embeddings(context.embeddings)
+ else:
+ assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
+ f"context.prompts={context.prompts}\n"
+ f"context.embeddings={context.embeddings}")
+ for embedding in context.embeddings:
+ context.prompts.pop()
+ assert_embeddings(embedding)
@step(u'an OAI compatible embeddings computation request for')
-def step_oai_compute_embedding(context):
- openai.api_key = 'nope' # openai client always expects an api_keu
- if context.user_api_key is not None:
- openai.api_key = context.user_api_key
- openai.api_base = f'{context.base_url}/v1'
- embeddings = openai.Embedding.create(
- model=context.model,
- input=context.text,
- )
- context.embeddings = embeddings
+@async_run_until_complete
+async def step_oai_compute_embeddings(context):
+ context.embeddings = await request_oai_embeddings(context.text,
+ base_url=context.base_url,
+ user_api_key=context.user_api_key,
+ model=context.model)
+
+
+@step(u'an OAI compatible embeddings computation request for multiple inputs')
+@async_run_until_complete
+async def step_oai_compute_embeddings_multiple_inputs(context):
+ context.embeddings = await request_oai_embeddings(context.prompts,
+ base_url=context.base_url,
+ user_api_key=context.user_api_key,
+ model=context.model)
@step(u'concurrent embedding requests')
@async_run_until_complete()
async def step_concurrent_embedding_requests(context):
- await concurrent_completion_requests(context,
- request_embedding,
- # prompt is inserted automatically
- context.base_url)
+ await concurrent_requests(context,
+ request_embedding,
+ # prompt is inserted automatically
+ base_url=context.base_url)
+
+
+@step(u'concurrent OAI embedding requests')
+@async_run_until_complete()
+async def step_concurrent_oai_embedding_requests(context):
+ await concurrent_requests(context,
+ request_oai_embeddings,
+ # prompt is inserted automatically
+ base_url=context.base_url,
+ async_client=True,
+ model=context.model)
@step(u'all embeddings are generated')
assert context.options_response.headers[cors_header] == cors_header_value
-async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
+async def concurrent_requests(context, f_completion, *args, **kwargs):
n_prompts = len(context.prompts)
if context.debug:
print(f"starting {n_prompts} concurrent completion requests...")
return completion_response
-async def request_embedding(content, base_url):
+async def request_embedding(content, base_url=None):
async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/embedding',
json={
return response_json['embedding']
+async def request_oai_embeddings(input,
+ base_url=None, user_api_key=None,
+ model=None, async_client=False):
+ # openai client always expects an api_key
+ user_api_key = user_api_key if user_api_key is not None else 'nope'
+ if async_client:
+ origin = 'llama.cpp'
+ if user_api_key is not None:
+ headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f'{base_url}/v1/embeddings',
+ json={
+ "input": input,
+ "model": model,
+ },
+ headers=headers) as response:
+ assert response.status == 200, f"received status code not expected: {response.status}"
+ assert response.headers['Access-Control-Allow-Origin'] == origin
+ assert response.headers['Content-Type'] == "application/json; charset=utf-8"
+ response_json = await response.json()
+ assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
+ assert response_json['object'] == 'list'
+ return response_json['data']
+ else:
+ openai.api_key = user_api_key
+ openai.api_base = f'{base_url}/v1'
+ oai_embeddings = openai.Embedding.create(
+ model=model,
+ input=input,
+ )
+
+ if isinstance(input, collections.abc.Sequence):
+ embeddings = []
+ for an_oai_embeddings in oai_embeddings.data:
+ embeddings.append(an_oai_embeddings.embedding)
+ else:
+ embeddings = oai_embeddings.data.embedding
+ return embeddings
+
+
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
content = completion_response['content']
n_predicted = completion_response['timings']['predicted_n']