ps.tiktoken
ps.torchWithoutCuda
ps.transformers
+
+ # server bench
+ ps.matplotlib
+
+ # server tests
+ ps.openai
+ ps.behave
+ ps.prometheus-client
+
+ # for examples/pydantic-models-to-grammar-examples.py
+ ps.docstring-parser
+ ps.pydantic
+
+ # for scripts/compare-llama-bench.py
+ ps.gitpython
+ ps.tabulate
]
);
--- /dev/null
+name: Python Type-Check
+
+on:
+ push:
+ paths:
+ - '.github/workflows/python-type-check.yml'
+ - '**.py'
+ - '**/requirements*.txt'
+ pull_request:
+ paths:
+ - '.github/workflows/python-type-check.yml'
+ - '**.py'
+ - '**/requirements*.txt'
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ python-type-check:
+ runs-on: ubuntu-latest
+ name: pyright type-check
+ steps:
+ - name: Check out source repository
+ uses: actions/checkout@v4
+ - name: Set up Python environment
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.11"
+ - name: Install Python dependencies
+ # TODO: use a venv
+ run: pip install -r requirements/requirements-all.txt
+ - name: Type-check with Pyright
+ uses: jakebailey/pyright-action@v2
+ with:
+ version: 1.1.370
+ level: warning
+ warnings: true
break
for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
- data: np.ndarray = data # type hint
+ data: np.ndarray # type hint
n_dims = len(data.shape)
data_dtype = data.dtype
data_qtype: gguf.GGMLQuantizationType | None = None
tokenizer_path = self.dir_model / 'tokenizer.model'
- tokens: list[bytes] = []
- scores: list[float] = []
- toktypes: list[int] = []
-
if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}")
logger.error(f'Error: Missing {tokenizer_path}')
sys.exit(1)
- sentencepiece_model = model.ModelProto()
+ sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}")
- sentencepiece_model = model.ModelProto()
+ sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
- if sentencepiece_model.trainer_spec.model_type == 2: # BPE
+ if sentencepiece_model.trainer_spec.model_type == 2: # BPE
# ensure the tokenizer model file name is correct
assert tokenizer_path.name == 'tokenizer.model'
return self._set_vocab_sentencepiece()
else:
- assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
+ assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
# but Jais's PyTorch model simply precalculates the slope values and places them
# in relative_pes.slopes
n_head_closest_log2 = 2 ** math.floor(math.log2(self.hparams["n_head"]))
- first_val = float(data_torch._data[0])
+ first_val = float(data_torch[0].item())
self.max_alibi_bias = -round(math.log2(first_val) * n_head_closest_log2)
return tensors
def set_vocab_chatglm3(self):
dir_model = self.dir_model
hparams = self.hparams
- tokens: list[bytearray] = []
+ tokens: list[bytes] = []
toktypes: list[int] = []
scores: list[float] = []
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
- self.gguf_writer.add_name(self.hparams.get("_name_or_path").split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b
+ self.gguf_writer.add_name(self.hparams["_name_or_path"].split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
n_head_kv = self.hparams.get("multi_query_group_num", n_head)
def handle_metadata(cfg, hp):
- import convert
+ import examples.convert_legacy_llama as convert
+
assert cfg.model_metadata_dir.is_dir(), 'Metadata dir is not a directory'
hf_config_path = cfg.model_metadata_dir / "config.json"
orig_config_path = cfg.model_metadata_dir / "params.json"
version: Optional[str] = None
url: Optional[str] = None
description: Optional[str] = None
- licence: Optional[str] = None
+ license: Optional[str] = None
source_url: Optional[str] = None
source_hf_repo: Optional[str] = None
LazyModel: TypeAlias = 'dict[str, LazyTensor]'
+ModelFormat: TypeAlias = Literal['ggml', 'torch', 'safetensors', 'none']
@dataclass
class ModelPlus:
model: LazyModel
paths: list[Path] # Where this was read from.
- format: Literal['ggml', 'torch', 'safetensors', 'none']
+ format: ModelFormat
vocab: BaseVocab | None # For GGML models (which have vocab built in), the vocab.
def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
- formats = set(mp.format for mp in models_plus)
+ formats: set[ModelFormat] = set(mp.format for mp in models_plus)
assert len(formats) == 1, "different formats?"
format = formats.pop()
paths = [path for mp in models_plus for path in mp.paths]
else:
model = merge_sharded([mp.model for mp in models_plus])
- return ModelPlus(model, paths, format, vocab) # pytype: disable=wrong-arg-types
+ return ModelPlus(model, paths, format, vocab)
def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor:
def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
- def add_meta_model(self, params: Params, metadata: Metadata) -> None:
+ def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
# Metadata About The Model And Its Provenance
name = "LLaMA"
if metadata is not None and metadata.name is not None:
self.gguf.add_url(metadata.url)
if metadata.description is not None:
self.gguf.add_description(metadata.description)
- if metadata.licence is not None:
- self.gguf.add_licence(metadata.licence)
+ if metadata.license is not None:
+ self.gguf.add_licence(metadata.license)
if metadata.source_url is not None:
self.gguf.add_source_url(metadata.source_url)
if metadata.source_hf_repo is not None:
@staticmethod
def write_vocab_only(
fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
- endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata = None,
+ endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
pad_vocab: bool = False,
- metadata: Metadata = None,
+ metadata: Metadata | None = None,
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
vocab = model_plus.vocab
+ assert params is not None
+
logger.info(f"Vocab info: {vocab}")
logger.info(f"Special vocab info: {special_vocab}")
model = model_plus.model
if len(self.ne) == 0:
self.nbytes = 0
else:
- self.nbytes = int(np.product(self.ne)) * 4
+ self.nbytes = int(np.prod(self.ne)) * 4
else:
raise ValueError(f"Unhandled data type '{self.dtype}'")
#! pip install pydantic
#! python json_schema_pydantic_example.py
-from pydantic import BaseModel, Extra, TypeAdapter
+from pydantic import BaseModel, Field, TypeAdapter
from annotated_types import MinLen
from typing import Annotated, List, Optional
import json, requests
The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below)
'''
+ response_format = None
+ type_adapter = None
+
if response_model:
type_adapter = TypeAdapter(response_model)
schema = type_adapter.json_schema()
#!/usr/bin/env python3
+from __future__ import annotations
+
import argparse
import itertools
import json
raise RuntimeError("At least one of min_value or max_value must be set")
class BuiltinRule:
- def __init__(self, content: str, deps: list = None):
+ def __init__(self, content: str, deps: list | None = None):
self.content = content
self.deps = deps or []
def _format_literal(self, literal):
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
- lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
+ lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
)
return f'"{escaped}"'
i = 0
length = len(pattern)
- def to_rule(s: Tuple[str, bool]) -> str:
+ def to_rule(s: tuple[str, bool]) -> str:
(txt, is_literal) = s
return "\"" + txt + "\"" if is_literal else txt
- def transform() -> Tuple[str, bool]:
+ def transform() -> tuple[str, bool]:
'''
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
'''
# We only need a flat structure here to apply repetition operators to the last item, and
# to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
# (GBNF's syntax is luckily very close to regular expressions!)
- seq: list[Tuple[str, bool]] = []
+ seq: list[tuple[str, bool]] = []
def get_dot():
if self._dotall:
fout.add_description("two-tower CLIP model")
if has_text_encoder:
+ assert t_hparams is not None
+ assert tokens is not None
# text_model hparams
fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
if processor is not None:
- image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
- image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
+ image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue]
+ image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std # pyright: ignore[reportAttributeAccessIssue]
else:
image_mean = args.image_mean if args.image_mean is not None else default_image_mean
image_std = args.image_std if args.image_std is not None else default_image_std
if has_llava_projector:
- model.vision_model.encoder.layers.pop(-1)
+ model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue]
projector = torch.load(args.llava_projector)
for name, data in projector.items():
name = get_tensor_name(name)
print("Projector tensors added\n")
-state_dict = model.state_dict()
+state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue]
for name, data in state_dict.items():
if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
# we don't need this
import glob
import os
import torch
-from safetensors.torch import load as safe_load, save as safe_save, safe_open, save_file
+from safetensors import safe_open
+from safetensors.torch import save_file
+from typing import Any, ContextManager, cast
# Function to determine if file is a SafeTensor file
def is_safetensor_file(file_path):
def load_model(file_path):
if is_safetensor_file(file_path):
tensors = {}
- with safe_open(file_path, framework="pt", device="cpu") as f:
+ with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key).clone()
# output shape
if last_checkpoint is not None:
for k, v in last_checkpoint.items():
print(k)
- print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
+ print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
print("No tensors found. Is this a LLaVA model?")
exit()
# projector = {name: checkpoint.[name].float() for name in mm_tensors}
projector = {}
for name in mm_tensors:
+ assert last_checkpoint is not None
projector[name] = last_checkpoint[name].float()
for name in first_mm_tensors:
+ assert first_checkpoint is not None
projector[name] = first_checkpoint[name].float()
if len(projector) > 0:
from copy import copy
from enum import Enum
from inspect import getdoc, isclass
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin
from docstring_parser import parse
-from pydantic import BaseModel, Field, create_model
+from pydantic import BaseModel, create_model
if TYPE_CHECKING:
from types import GenericAlias
# python 3.8 compat
from typing import _GenericAlias as GenericAlias
+# TODO: fix this
+# pyright: reportAttributeAccessIssue=information
+
class PydanticDataType(Enum):
"""
# Define the integer part rule
integer_part_rule = (
- "integer-part" + (f"-max{max_digit}" if max_digit is not None else "") + (
- f"-min{min_digit}" if min_digit is not None else "")
+ "integer-part"
+ + (f"-max{max_digit}" if max_digit is not None else "")
+ + (f"-min{min_digit}" if min_digit is not None else "")
)
# Define the fractional part rule based on precision constraints
if not issubclass(model, BaseModel):
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__
if hasattr(model, "__annotations__") and model.__annotations__:
- model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()}
+ model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues]
else:
init_signature = inspect.signature(model.__init__)
parameters = init_signature.parameters
str: Generated text documentation.
"""
documentation = ""
- pyd_models = [(model, True) for model in pydantic_models]
+ pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models]
for model, add_prefix in pyd_models:
if add_prefix:
documentation += f"{model_prefix}: {model.__name__}\n"
# Indenting the fields section
documentation += f" {fields_prefix}:\n"
else:
- documentation += f" Fields:\n"
+ documentation += f" Fields:\n" # noqa: F541
if isclass(model) and issubclass(model, BaseModel):
for name, field_type in model.__annotations__.items():
# if name == "markdown_code_block":
return field_text
if field_description != "":
- field_text += f" Description: " + field_description + "\n"
+ field_text += f" Description: {field_description}\n"
# Check for and include field-specific examples if available
if hasattr(model, "Config") and hasattr(model.Config,
str: Generated text documentation.
"""
documentation = ""
- pyd_models = [(model, True) for model in pydantic_models]
+ pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models]
for model, add_prefix in pyd_models:
if add_prefix:
documentation += f"{model_prefix}: {model.__name__}\n"
dynamic_fields[param.name] = (
param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
# Creating the dynamic model
- dynamic_model = create_model(f"{func.__name__}", **dynamic_fields) # type: ignore[call-overload]
+ dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)
for name, param_doc in param_docs:
dynamic_model.model_fields[name].description = param_doc.description
return output
-from enum import Enum
-
-
def json_schema_to_python_types(schema):
type_map = {
"any": Any,
if items != {}:
array = {"properties": items}
array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
- fields[field_name] = (List[array_type], ...) # type: ignore[valid-type]
+ fields[field_name] = (List[array_type], ...)
else:
fields[field_name] = (list, ...)
elif field_type == "object":
required = field_data.get("enum", [])
for key, field in fields.items():
if key not in required:
- fields[key] = (Optional[fields[key][0]], ...)
+ optional_type = fields[key][0]
+ fields[key] = (Optional[optional_type], ...)
else:
field_type = json_schema_to_python_types(field_type)
fields[field_name] = (field_type, ...)
required = dictionary.get("required", [])
for key, field in fields.items():
if key not in required:
- fields[key] = (Optional[fields[key][0]], ...)
+ optional_type = fields[key][0]
+ fields[key] = (Optional[optional_type], ...)
custom_model = create_model(model_name, **fields)
return custom_model
# Function calling example using pydantic models.
+from __future__ import annotations
+
import datetime
-import importlib
import json
from enum import Enum
from typing import Optional, Union
if call["function"] == "Calculator":
print(Calculator(**call["params"]).run())
elif call["function"] == "get_current_datetime":
- print(current_datetime_model(**call["params"]).run())
+ print(current_datetime_model(**call["params"]).run()) # pyright: ignore[reportAttributeAccessIssue]
elif call["function"] == "get_current_weather":
- print(current_weather_tool_model(**call["params"]).run())
+ print(current_weather_tool_model(**call["params"]).run()) # pyright: ignore[reportAttributeAccessIssue]
# Should output something like this:
# 2024-01-14 13:36:06
# {"location": "London", "temperature": "42", "unit": "celsius"}
+from __future__ import annotations
+
import argparse
import json
import os
sys.exit(1)
# start the benchmark
+ iterations = 0
+ data = {}
try:
start_benchmark(args)
- iterations = 0
with open("results.github.env", 'w') as github_env:
# parse output
with open('k6-results.json', 'r') as bench_results:
timestamps, metric_values = zip(*values)
metric_values = [float(value) for value in metric_values]
prometheus_metrics[metric] = metric_values
- timestamps_dt = [datetime.fromtimestamp(int(ts)) for ts in timestamps]
+ timestamps_dt = [str(datetime.fromtimestamp(int(ts))) for ts in timestamps]
plt.figure(figsize=(16, 10), dpi=80)
plt.plot(timestamps_dt, metric_values, label=metric)
plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)
plt.close()
# Mermaid format in case images upload failed
- with (open(f"{metric}.mermaid", 'w') as mermaid_f):
+ with open(f"{metric}.mermaid", 'w') as mermaid_f:
mermaid = (
f"""---
config:
}
server_process = subprocess.Popen(
args,
- **pkwargs)
+ **pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue]
def server_log(in_stream, out_stream):
for line in iter(in_stream.readline, b''):
import asyncio
-import collections
import json
import os
import re
import sys
import threading
import time
+from collections.abc import Sequence
from contextlib import closing
from re import RegexFlag
+from typing import Any, Literal, cast
import aiohttp
import numpy as np
import openai
-from behave import step
+from openai.types.chat import ChatCompletionChunk
+from behave import step # pyright: ignore[reportAttributeAccessIssue]
from behave.api.async_step import async_run_until_complete
from prometheus_client import parser
+# pyright: reportRedeclaration=false
@step("a server listening on {server_fqdn}:{server_port}")
-def step_server_config(context, server_fqdn, server_port):
+def step_server_config(context, server_fqdn: str, server_port: str):
context.server_fqdn = server_fqdn
context.server_port = int(server_port)
context.n_threads = None
@step('a model file {hf_file} from HF repo {hf_repo}')
-def step_download_hf_model(context, hf_file, hf_repo):
+def step_download_hf_model(context, hf_file: str, hf_repo: str):
context.model_hf_repo = hf_repo
context.model_hf_file = hf_file
context.model_file = os.path.basename(hf_file)
@step('a model file {model_file}')
-def step_model_file(context, model_file):
+def step_model_file(context, model_file: str):
context.model_file = model_file
@step('a model url {model_url}')
-def step_model_url(context, model_url):
+def step_model_url(context, model_url: str):
context.model_url = model_url
@step('a model alias {model_alias}')
-def step_model_alias(context, model_alias):
+def step_model_alias(context, model_alias: str):
context.model_alias = model_alias
@step('{seed:d} as server seed')
-def step_seed(context, seed):
+def step_seed(context, seed: int):
context.server_seed = seed
@step('{ngl:d} GPU offloaded layers')
-def step_n_gpu_layer(context, ngl):
+def step_n_gpu_layer(context, ngl: int):
if 'N_GPU_LAYERS' in os.environ:
new_ngl = int(os.environ['N_GPU_LAYERS'])
if context.debug:
@step('{n_threads:d} threads')
-def step_n_threads(context, n_threads):
+def step_n_threads(context, n_threads: int):
context.n_thread = n_threads
@step('{draft:d} as draft')
-def step_draft(context, draft):
+def step_draft(context, draft: int):
context.draft = draft
@step('{n_ctx:d} KV cache size')
-def step_n_ctx(context, n_ctx):
+def step_n_ctx(context, n_ctx: int):
context.n_ctx = n_ctx
@step('{n_slots:d} slots')
-def step_n_slots(context, n_slots):
+def step_n_slots(context, n_slots: int):
context.n_slots = n_slots
@step('{n_predict:d} server max tokens to predict')
-def step_server_n_predict(context, n_predict):
+def step_server_n_predict(context, n_predict: int):
context.n_server_predict = n_predict
@step('{slot_save_path} as slot save path')
-def step_slot_save_path(context, slot_save_path):
+def step_slot_save_path(context, slot_save_path: str):
context.slot_save_path = slot_save_path
@step('using slot id {id_slot:d}')
-def step_id_slot(context, id_slot):
+def step_id_slot(context, id_slot: int):
context.id_slot = id_slot
@step("the server is {expecting_status}")
@async_run_until_complete
-async def step_wait_for_the_server_to_be_started(context, expecting_status):
+async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str):
match expecting_status:
case 'healthy':
await wait_for_health_status(context, context.base_url, 200, 'ok',
@step('all slots are {expected_slot_status_string}')
@async_run_until_complete
-async def step_all_slots_status(context, expected_slot_status_string):
+async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):
match expected_slot_status_string:
case 'idle':
expected_slot_status = 0
@step('a completion request with {api_error} api error')
@async_run_until_complete
-async def step_request_completion(context, api_error):
+async def step_request_completion(context, api_error: Literal['raised'] | str):
expect_api_error = api_error == 'raised'
seeds = await completions_seed(context, num_seeds=1)
completion = await request_completion(context.prompts.pop(),
def step_available_models(context):
# openai client always expects an api_key
openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
- openai.api_base = f'{context.base_url}/v1'
- context.models = openai.Model.list().data
+ openai.base_url = f'{context.base_url}/v1/'
+ context.models = openai.models.list().data
@step('{n_model:d} models are supported')
@step('model {i_model:d} is {param} {preposition} {param_value}')
-def step_supported_models(context, i_model, param, preposition, param_value):
+def step_supported_models(context, i_model: int, param: Literal['identified', 'trained'] | str, preposition: str, param_value: str):
assert i_model < len(context.models)
model = context.models[i_model]
case 'identified':
value = model.id
case 'trained':
- value = str(model.meta.n_ctx_train)
+ value = str(model.meta["n_ctx_train"])
case _:
assert False, "param {param} not supported"
assert param_value == value, f"model param {param} {value} != {param_value}"
print(f"starting {context.n_prompts} concurrent completion requests...")
assert context.n_prompts > 0
seeds = await completions_seed(context)
+ assert seeds is not None
for prompt_no in range(context.n_prompts):
shifted_args = [context.prompts.pop(), seeds[prompt_no], *args]
context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
id_slot=None,
expect_api_error=None,
user_api_key=None,
- temperature=None):
+ temperature=None) -> int | dict[str, Any]:
if debug:
print(f"Sending completion request: {prompt}")
origin = "my.super.domain"
async def oai_chat_completions(user_prompt,
seed,
system_prompt,
- base_url,
- base_path,
+ base_url: str,
+ base_path: str,
async_client,
debug=False,
temperature=None,
enable_streaming=None,
response_format=None,
user_api_key=None,
- expect_api_error=None):
+ expect_api_error=None) -> int | dict[str, Any]:
if debug:
print(f"Sending OAI Chat completions request: {user_prompt}")
# openai client always expects an api key
else:
try:
openai.api_key = user_api_key
- openai.api_base = f'{base_url}{base_path}'
- chat_completion = openai.Completion.create(
+ openai.base_url = f'{base_url}{base_path.removesuffix("chat")}'
+ assert model is not None
+ chat_completion = openai.chat.completions.create(
messages=payload['messages'],
model=model,
max_tokens=n_predict,
stream=enable_streaming,
- response_format=payload.get('response_format'),
+ response_format=payload.get('response_format') or openai.NOT_GIVEN,
seed=seed,
temperature=payload['temperature']
)
- except openai.error.AuthenticationError as e:
+ except openai.AuthenticationError as e:
if expect_api_error is not None and expect_api_error:
return 401
else:
assert False, f'error raised: {e}'
if enable_streaming:
+ chat_completion = cast(openai.Stream[ChatCompletionChunk], chat_completion)
for chunk in chat_completion:
assert len(chunk.choices) == 1
delta = chunk.choices[0].delta
- if 'content' in delta:
- completion_response['content'] += delta['content']
+ if delta.content is not None:
+ completion_response['content'] += delta.content
completion_response['timings']['predicted_n'] += 1
completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
else:
assert len(chat_completion.choices) == 1
+ assert chat_completion.usage is not None
completion_response = {
'content': chat_completion.choices[0].message.content,
'timings': {
return completion_response
-async def request_embedding(content, seed, base_url=None):
+async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/embedding',
json={
async def request_oai_embeddings(input, seed,
base_url=None, user_api_key=None,
- model=None, async_client=False):
+ model=None, async_client=False) -> list[list[float]]:
# openai client always expects an api_key
user_api_key = user_api_key if user_api_key is not None else 'nope'
if async_client:
response_json = await response.json()
assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
assert response_json['object'] == 'list'
- if isinstance(input, collections.abc.Sequence):
+ if isinstance(input, Sequence):
embeddings = []
for an_oai_embeddings in response_json['data']:
embeddings.append(an_oai_embeddings['embedding'])
return embeddings
else:
openai.api_key = user_api_key
- openai.api_base = f'{base_url}/v1'
- oai_embeddings = openai.Embedding.create(
+ openai.base_url = f'{base_url}/v1/'
+ assert model is not None
+ oai_embeddings = openai.embeddings.create(
model=model,
input=input,
)
- if isinstance(input, collections.abc.Sequence):
- embeddings = []
- for an_oai_embeddings in oai_embeddings.data:
- embeddings.append(an_oai_embeddings.embedding)
- else:
- embeddings = [oai_embeddings.data.embedding]
- return embeddings
+ return [e.embedding for e in oai_embeddings.data]
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
if i == j:
continue
content_j = response_j['content']
- assert content_i == content_j, "contents not equal"
+ assert content_i == content_j, "contents not equal"
def assert_all_predictions_different(completion_responses):
if i == j:
continue
content_j = response_j['content']
- assert content_i != content_j, "contents not different"
+ assert content_i != content_j, "contents not different"
def assert_all_token_probabilities_equal(completion_responses):
if i == j:
continue
probs_j = response_j['completion_probabilities'][pos]['probs']
- assert probs_i == probs_j, "contents not equal"
+ assert probs_i == probs_j, "contents not equal"
async def gather_tasks_results(context):
}
context.server_process = subprocess.Popen(
[str(arg) for arg in [context.server_path, *server_args]],
- **pkwargs)
+ **pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue]
def server_log(in_stream, out_stream):
for line in iter(in_stream.readline, b''):
aiohttp~=3.9.3
behave~=1.2.6
huggingface_hub~=0.20.3
-numpy~=1.24.4
-openai~=0.25.0
+numpy~=1.26.4
+openai~=1.30.3
prometheus-client~=0.20.0
import asyncio
+import asyncio.threads
import requests
import numpy as np
+
n = 8
result = []
async def requests_post_async(*args, **kwargs):
- return await asyncio.to_thread(requests.post, *args, **kwargs)
+ return await asyncio.threads.to_thread(requests.post, *args, **kwargs)
async def main():
model_url = "http://127.0.0.1:6900"
if len(self.ne) == 0:
self.nbytes = 0
else:
- self.nbytes = int(np.product(self.ne)) * 4
+ self.nbytes = int(np.prod(self.ne)) * 4
else:
raise ValueError(f"Unhandled data type '{self.dtype}'")
tasks = []
+ base_dict = {"FLOAT_TYPE": "float"}
+
for fp16 in (False, True):
# MUL_MAT
matmul_shaders(tasks, fp16, False)
matmul_shaders(tasks, fp16, True)
for tname in type_names:
- base_dict = {"FLOAT_TYPE": "float"}
-
# mul mat vec
data_a_key = f"DATA_A_{tname.upper()}"
shader = f"mul_mat_vec_{tname}.comp" if tname.endswith("_k") else "mul_mat_vec.comp"
class GGUFReader:
# I - same as host, S - swapped
- byte_order: Literal['I'] | Literal['S'] = 'I'
+ byte_order: Literal['I', 'S'] = 'I'
alignment: int = GGUF_DEFAULT_ALIGNMENT
data_offset: int
GGUFValueType.BOOL: np.bool_,
}
- def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
+ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
self.data = np.memmap(path, mode = mode)
offs = 0
return self.tensors[idx]
def _get(
- self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
+ self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
) -> npt.NDArray[Any]:
count = int(count)
itemsize = int(np.empty([], dtype = dtype).itemsize)
class LazyMeta(ABCMeta):
def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
- def __getattr__(self, __name: str) -> Any:
- meta_attr = getattr(self._meta, __name)
+ def __getattr__(self, name: str) -> Any:
+ meta_attr = getattr(self._meta, name)
if callable(meta_attr):
return type(self)._wrap_fn(
- (lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
+ (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
use_self=self,
)
elif isinstance(meta_attr, self._tensor_type):
# e.g. self.T with torch.Tensor should still be wrapped
- return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
+ return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
else:
# no need to wrap non-tensor properties,
# and they likely don't depend on the actual contents of the tensor
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
if isinstance(res, cls._tensor_type):
- def collect_replace(t: LazyBase):
- if collect_replace.shared_lazy is None:
- collect_replace.shared_lazy = t._lazy
- else:
- collect_replace.shared_lazy.extend(t._lazy)
- t._lazy = collect_replace.shared_lazy
+ class CollectSharedLazy:
+ # emulating a static variable
+ shared_lazy: None | deque[LazyBase] = None
- # emulating a static variable
- collect_replace.shared_lazy = None
+ @staticmethod
+ def collect_replace(t: LazyBase):
+ if CollectSharedLazy.shared_lazy is None:
+ CollectSharedLazy.shared_lazy = t._lazy
+ else:
+ CollectSharedLazy.shared_lazy.extend(t._lazy)
+ t._lazy = CollectSharedLazy.shared_lazy
- LazyBase._recurse_apply(args, collect_replace)
+ LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
- shared_lazy = collect_replace.shared_lazy
+ shared_lazy = CollectSharedLazy.shared_lazy
return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
else:
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
lt._data = lt._func(lt._args)
# sanity check
+ assert lt._data is not None
assert lt._data.dtype == lt._meta.dtype
assert lt._data.shape == lt._meta.shape
+# pyright: reportUnusedImport=false
+
from .gguf_convert_endian import main as gguf_convert_endian_entrypoint
from .gguf_dump import main as gguf_dump_entrypoint
from .gguf_set_metadata import main as gguf_set_metadata_entrypoint
bar.update(sum_weights_in_tensor)
sha1_layer = hashlib.sha1()
- sha1_layer.update(tensor.data)
- sha1.update(tensor.data)
- uuidv5_sha1.update(tensor.data)
+ sha1_layer.update(tensor.data.data)
+ sha1.update(tensor.data.data)
+ uuidv5_sha1.update(tensor.data.data)
print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
# Flush Hash Progress Bar
#!/usr/bin/env python3
+from __future__ import annotations
+
import logging
import argparse
import os
-import gguf # noqa: F401
+import gguf # noqa: F401 # pyright: ignore[reportUnusedImport]
# TODO: add tests
{
"extraPaths": ["gguf-py"],
-}
+ "pythonVersion": "3.9",
+ "pythonPlatform": "All",
+ "reportUnusedImport": "warning",
+ "reportDuplicateImport": "error",
+ "reportDeprecated": "warning",
+ "reportUnnecessaryTypeIgnoreComment": "warning",
+ "executionEnvironments": [
+ {
+ // TODO: make this version override work correctly
+ "root": "gguf-py",
+ "pythonVersion": "3.8",
+ },
+ {
+ // uses match expressions in steps.py
+ "root": "examples/server/tests",
+ "pythonVersion": "3.10",
+ },
+ ],
+ }
--- /dev/null
+-r ../examples/llava/requirements.txt
+-r ../examples/server/bench/requirements.txt
+-r ../examples/server/tests/requirements.txt
+
+-r ./requirements-compare-llama-bench.txt
+-r ./requirements-pydantic.txt
+-r ./requirements-test-tokenizer-random.txt
+
+-r ./requirements-convert_hf_to_gguf.txt
+-r ./requirements-convert_hf_to_gguf_update.txt
+-r ./requirements-convert_legacy_llama.txt
+-r ./requirements-convert_llama_ggml_to_gguf.txt
--- /dev/null
+tabulate~=0.9.0
+GitPython~=3.1.43
--- /dev/null
+docstring_parser~=0.15
+pydantic~=2.6.3
--- /dev/null
+cffi~=1.16.0
fatal "$py missing requirements. Expected: $reqs"
fi
+ # Check that all sub-requirements are added to top-level requirements.txt
+ if ! grep -qF "$reqs" requirements.txt; then
+ fatal "$reqs needs to be added to requirements.txt"
+ fi
+
local venv="$workdir/$pyname-venv"
python3 -m venv "$venv"
readonly ignore_eq_eq='check_requirements: ignore "=="'
-for req in "$reqs_dir"/*; do
- # Check that all sub-requirements are added to top-level requirements.txt
- if ! grep -qF "$req" requirements.txt; then
- fatal "$req needs to be added to requirements.txt"
- fi
-
+for req in */**/requirements*.txt; do
# Make sure exact release versions aren't being pinned in the requirements
# Filters out the ignore string
if grep -vF "$ignore_eq_eq" "$req" | grep -q '=='; then
try:
repo = git.Repo(".", search_parent_directories=True)
-except git.exc.InvalidGitRepositoryError:
+except git.InvalidGitRepositoryError:
repo = None
-def find_parent_in_data(commit):
+def find_parent_in_data(commit: git.Commit):
"""Helper function to find the most recent parent measured in number of commits for which there is data."""
- heap = [(0, commit)]
+ heap: list[tuple[int, git.Commit]] = [(0, commit)]
seen_hexsha8 = set()
while heap:
depth, current_commit = heapq.heappop(heap)
return None
-def get_all_parent_hexsha8s(commit):
+def get_all_parent_hexsha8s(commit: git.Commit):
"""Helper function to recursively get hexsha8 values for all parents of a commit."""
unvisited = [commit]
visited = []
+from __future__ import annotations
+
import array
import unicodedata
import requests
# group ranges with same flags
-ranges_flags = [(0, codepoint_flags[0])] # start, flags
+ranges_flags: list[tuple[int, int]] = [(0, codepoint_flags[0])] # start, flags
for codepoint, flags in enumerate(codepoint_flags):
if flags != ranges_flags[-1][1]:
ranges_flags.append((codepoint, flags))
# group ranges with same nfd
-ranges_nfd = [(0, 0, 0)] # start, last, nfd
+ranges_nfd: list[tuple[int, int, int]] = [(0, 0, 0)] # start, last, nfd
for codepoint, norm in table_nfd:
start = ranges_nfd[-1][0]
if ranges_nfd[-1] != (start, codepoint - 1, norm):
- ranges_nfd.append(None)
+ ranges_nfd.append(None) # type: ignore[arg-type] # dummy, will be replaced below
start = codepoint
ranges_nfd[-1] = (start, codepoint, norm)
out("};\n")
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
-for tuple in table_lowercase:
- out("{0x%06X, 0x%06X}," % tuple)
+for tuple_lw in table_lowercase:
+ out("{0x%06X, 0x%06X}," % tuple_lw)
out("};\n")
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
-for tuple in table_uppercase:
- out("{0x%06X, 0x%06X}," % tuple)
+for tuple_up in table_uppercase:
+ out("{0x%06X, 0x%06X}," % tuple_up)
out("};\n")
out("const std::vector<range_nfd> unicode_ranges_nfd = { // start, last, nfd")
# python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
#
+from __future__ import annotations
+
import time
import logging
import argparse
import random
import unicodedata
-from typing import Iterator
+from pathlib import Path
+from typing import Any, Iterator, cast
+from typing_extensions import Buffer
import cffi
from transformers import AutoTokenizer
DEFAULT_PATH_INCLUDES = ["./ggml/include/", "./include/"]
DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON
- def __init__(self, path_llama_h: str = None, path_includes: list[str] = [], path_libllama: str = None):
+ def __init__(self, path_llama_h: str | None = None, path_includes: list[str] = [], path_libllama: str | None = None):
path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
path_includes = path_includes or self.DEFAULT_PATH_INCLUDES
path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
(self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_includes, path_libllama)
self.lib.llama_backend_init()
- def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str):
- cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="]
+ def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str) -> tuple[cffi.FFI, Any]:
+ cmd = ["gcc", "-O0", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="]
cmd += ["-I" + path for path in path_includes] + [path_llama_h]
res = subprocess.run(cmd, stdout=subprocess.PIPE)
assert (res.returncode == 0)
class LibLlamaModel:
def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
- self.lib = libllama.lib
+ self.lib: Any = libllama.lib
self.ffi = libllama.ffi
if isinstance(mparams, dict):
mparams = libllama.model_default_params(**mparams)
self.lib = None
def tokenize(self, text: str, add_special: bool = False, parse_special: bool = False) -> list[int]:
- text = text.encode("utf-8")
- num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special)
+ encoded_text: bytes = text.encode("utf-8")
+ num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
while num < 0 and len(self.token_ids) < (16 << 20):
self.token_ids = self.ffi.new("llama_token[]", -2 * num)
- num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special)
+ num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
return list(self.token_ids[0:num])
def detokenize(self, ids: list[int], remove_special: bool = False, unparse_special: bool = False) -> str:
while num < 0 and len(self.text_buff) < (16 << 20):
self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
- return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
+ return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
class Tokenizer:
class TokenizerLlamaCpp (Tokenizer):
- libllama: LibLlama = None
+ libllama: LibLlama | None = None
def __init__(self, vocab_file: str):
if not self.libllama:
def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
- def find_first_mismatch(ids1: list[int], ids2: list[int]):
+ def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
for i, (a, b) in enumerate(zip(ids1, ids2)):
if a != b:
return i
decode_errors = 0
MAX_ERRORS = 10
- logger.info("%s: %s" % (generator.__name__, "ini"))
+ logger.info("%s: %s" % (generator.__qualname__, "ini"))
for text in generator:
# print(repr(text), text.encode())
# print(repr(text), hex(ord(text[0])), text.encode())
break
t_total = time.perf_counter() - t_start
- logger.info(f"{generator.__name__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
+ logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
-def main(argv: list[str] = None):
+def main(argv: list[str] | None = None):
parser = argparse.ArgumentParser()
- parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
- parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
+ parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
+ parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
args = parser.parse_args(argv)
format = "%(levelname)s %(message)s",
)
- path_tokenizers = "./models/tokenizers/"
+ path_tokenizers = Path("./models/tokenizers/")
path_vocab_format = "./models/ggml-vocab-%s.gguf"
tokenizers = [
for tokenizer in tokenizers:
logger.info("-" * 50)
logger.info(f"TOKENIZER: '{tokenizer}'")
- vocab_file = path_vocab_format % tokenizer
- dir_tokenizer = path_tokenizers + "/" + tokenizer
- main([vocab_file, dir_tokenizer, "--verbose"])
+ vocab_file = Path(path_vocab_format % tokenizer)
+ dir_tokenizer = path_tokenizers / tokenizer
+ main([str(vocab_file), str(dir_tokenizer), "--verbose"])