]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
examples : make pydantic scripts pass mypy and support py3.8 (#5099)
authorJared Van Bortel <redacted>
Thu, 25 Jan 2024 19:51:24 +0000 (14:51 -0500)
committerGitHub <redacted>
Thu, 25 Jan 2024 19:51:24 +0000 (14:51 -0500)
examples/pydantic-models-to-grammar-examples.py
examples/pydantic_models_to_grammar.py

index cbf3766526ff7584086b3a269ee871e8e4e8f7ff..160966649b05dc07aba65b6762321e1140d2f017 100644 (file)
@@ -1,14 +1,14 @@
 # Function calling example using pydantic models.
 import datetime
+import importlib
 import json
 from enum import Enum
-from typing import Union, Optional
+from typing import Optional, Union
 
 import requests
 from pydantic import BaseModel, Field
-
-import importlib
-from pydantic_models_to_grammar import generate_gbnf_grammar_and_documentation, convert_dictionary_to_pydantic_model, add_run_method_to_dynamic_model, create_dynamic_model_from_function
+from pydantic_models_to_grammar import (add_run_method_to_dynamic_model, convert_dictionary_to_pydantic_model,
+                                        create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation)
 
 
 # Function to get completion on the llama.cpp server with grammar.
@@ -35,7 +35,7 @@ class SendMessageToUser(BaseModel):
         print(self.message)
 
 
-# Enum for the calculator function.
+# Enum for the calculator tool.
 class MathOperation(Enum):
     ADD = "add"
     SUBTRACT = "subtract"
@@ -43,7 +43,7 @@ class MathOperation(Enum):
     DIVIDE = "divide"
 
 
-# Very simple calculator tool for the agent.
+# Simple pydantic calculator tool for the agent that can add, subtract, multiply, and divide. Docstring and description of fields will be used in system prompt.
 class Calculator(BaseModel):
     """
     Perform a math operation on two numbers.
@@ -148,37 +148,6 @@ def get_current_datetime(output_format: Optional[str] = None):
     return datetime.datetime.now().strftime(output_format)
 
 
-# Enum for the calculator tool.
-class MathOperation(Enum):
-    ADD = "add"
-    SUBTRACT = "subtract"
-    MULTIPLY = "multiply"
-    DIVIDE = "divide"
-
-
-
-# Simple pydantic calculator tool for the agent that can add, subtract, multiply, and divide. Docstring and description of fields will be used in system prompt.
-class Calculator(BaseModel):
-    """
-    Perform a math operation on two numbers.
-    """
-    number_one: Union[int, float] = Field(..., description="First number.")
-    operation: MathOperation = Field(..., description="Math operation to perform.")
-    number_two: Union[int, float] = Field(..., description="Second number.")
-
-    def run(self):
-        if self.operation == MathOperation.ADD:
-            return self.number_one + self.number_two
-        elif self.operation == MathOperation.SUBTRACT:
-            return self.number_one - self.number_two
-        elif self.operation == MathOperation.MULTIPLY:
-            return self.number_one * self.number_two
-        elif self.operation == MathOperation.DIVIDE:
-            return self.number_one / self.number_two
-        else:
-            raise ValueError("Unknown operation.")
-
-
 # Example function to get the weather
 def get_current_weather(location, unit):
     """Get the current weather in a given location"""
index 848c1c367d7017f2681f9963d11ee74939bca3f1..9acc7cc6dcd859f5cf66f7bf04d3c8ece12355ab 100644 (file)
@@ -1,15 +1,21 @@
+from __future__ import annotations
+
 import inspect
 import json
+import re
 from copy import copy
-from inspect import isclass, getdoc
-from types import NoneType
+from enum import Enum
+from inspect import getdoc, isclass
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
 
 from docstring_parser import parse
-from pydantic import BaseModel, create_model, Field
-from typing import Any, Type, List, get_args, get_origin, Tuple, Union, Optional, _GenericAlias
-from enum import Enum
-from typing import get_type_hints, Callable
-import re
+from pydantic import BaseModel, Field, create_model
+
+if TYPE_CHECKING:
+    from types import GenericAlias
+else:
+    # python 3.8 compat
+    from typing import _GenericAlias as GenericAlias
 
 
 class PydanticDataType(Enum):
@@ -43,7 +49,7 @@ class PydanticDataType(Enum):
     SET = "set"
 
 
-def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
+def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
     if isclass(pydantic_type) and issubclass(pydantic_type, str):
         return PydanticDataType.STRING.value
     elif isclass(pydantic_type) and issubclass(pydantic_type, bool):
@@ -57,22 +63,22 @@ def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
 
     elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
         return format_model_and_field_name(pydantic_type.__name__)
-    elif get_origin(pydantic_type) == list:
+    elif get_origin(pydantic_type) is list:
         element_type = get_args(pydantic_type)[0]
         return f"{map_pydantic_type_to_gbnf(element_type)}-list"
-    elif get_origin(pydantic_type) == set:
+    elif get_origin(pydantic_type) is set:
         element_type = get_args(pydantic_type)[0]
         return f"{map_pydantic_type_to_gbnf(element_type)}-set"
-    elif get_origin(pydantic_type) == Union:
+    elif get_origin(pydantic_type) is Union:
         union_types = get_args(pydantic_type)
         union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
         return f"union-{'-or-'.join(union_rules)}"
-    elif get_origin(pydantic_type) == Optional:
+    elif get_origin(pydantic_type) is Optional:
         element_type = get_args(pydantic_type)[0]
         return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
     elif isclass(pydantic_type):
         return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
-    elif get_origin(pydantic_type) == dict:
+    elif get_origin(pydantic_type) is dict:
         key_type, value_type = get_args(pydantic_type)
         return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
     else:
@@ -106,7 +112,6 @@ def get_members_structure(cls, rule_name):
         return f"{cls.__name__.lower()} ::= " + " | ".join(members)
     if cls.__annotations__ and cls.__annotations__ != {}:
         result = f'{rule_name} ::= "{{"'
-        type_list_rules = []
         # Modify this comprehension
         members = [
             f'  "\\"{name}\\"" ":"  {map_pydantic_type_to_gbnf(param_type)}'
@@ -116,27 +121,25 @@ def get_members_structure(cls, rule_name):
 
         result += '"," '.join(members)
         result += '  "}"'
-        return result, type_list_rules
-    elif rule_name == "custom-class-any":
+        return result
+    if rule_name == "custom-class-any":
         result = f"{rule_name} ::= "
         result += "value"
-        type_list_rules = []
-        return result, type_list_rules
-    else:
-        init_signature = inspect.signature(cls.__init__)
-        parameters = init_signature.parameters
-        result = f'{rule_name} ::=  "{{"'
-        type_list_rules = []
-        # Modify this comprehension too
-        members = [
-            f'  "\\"{name}\\"" ":"  {map_pydantic_type_to_gbnf(param.annotation)}'
-            for name, param in parameters.items()
-            if name != "self" and param.annotation != inspect.Parameter.empty
-        ]
+        return result
 
-        result += '", "'.join(members)
-        result += '  "}"'
-        return result, type_list_rules
+    init_signature = inspect.signature(cls.__init__)
+    parameters = init_signature.parameters
+    result = f'{rule_name} ::=  "{{"'
+    # Modify this comprehension too
+    members = [
+        f'  "\\"{name}\\"" ":"  {map_pydantic_type_to_gbnf(param.annotation)}'
+        for name, param in parameters.items()
+        if name != "self" and param.annotation != inspect.Parameter.empty
+    ]
+
+    result += '", "'.join(members)
+    result += '  "}"'
+    return result
 
 
 def regex_to_gbnf(regex_pattern: str) -> str:
@@ -269,7 +272,7 @@ def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None
 
 def generate_gbnf_rule_for_type(
     model_name, field_name, field_type, is_optional, processed_models, created_rules, field_info=None
-) -> Tuple[str, list]:
+) -> tuple[str, list[str]]:
     """
     Generate GBNF rule for a given field type.
 
@@ -283,7 +286,7 @@ def generate_gbnf_rule_for_type(
     :param field_info: Additional information about the field (optional).
 
     :return: Tuple containing the GBNF type and a list of additional rules.
-    :rtype: Tuple[str, list]
+    :rtype: tuple[str, list]
     """
     rules = []
 
@@ -321,8 +324,7 @@ def generate_gbnf_rule_for_type(
         gbnf_type, rules = model_name + "-" + field_name, rules
 
     elif gbnf_type.startswith("custom-class-"):
-        nested_model_rules, field_types = get_members_structure(field_type, gbnf_type)
-        rules.append(nested_model_rules)
+        rules.append(get_members_structure(field_type, gbnf_type))
     elif gbnf_type.startswith("custom-dict-"):
         key_type, value_type = get_args(field_type)
 
@@ -341,14 +343,14 @@ def generate_gbnf_rule_for_type(
         union_rules = []
 
         for union_type in union_types:
-            if isinstance(union_type, _GenericAlias):
+            if isinstance(union_type, GenericAlias):
                 union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(
                     model_name, field_name, union_type, False, processed_models, created_rules
                 )
                 union_rules.append(union_gbnf_type)
                 rules.extend(union_rules_list)
 
-            elif not issubclass(union_type, NoneType):
+            elif not issubclass(union_type, type(None)):
                 union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(
                     model_name, field_name, union_type, False, processed_models, created_rules
                 )
@@ -424,14 +426,10 @@ def generate_gbnf_rule_for_type(
     else:
         gbnf_type, rules = gbnf_type, []
 
-    if gbnf_type not in created_rules:
-        return gbnf_type, rules
-    else:
-        if gbnf_type in created_rules:
-            return gbnf_type, rules
+    return gbnf_type, rules
 
 
-def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created_rules: dict) -> (list, bool, bool):
+def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[BaseModel]], created_rules: dict[str, list[str]]) -> tuple[list[str], bool]:
     """
 
     Generate GBnF Grammar
@@ -452,7 +450,7 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
     ```
     """
     if model in processed_models:
-        return []
+        return [], False
 
     processed_models.add(model)
     model_name = format_model_and_field_name(model.__name__)
@@ -518,7 +516,7 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
 
 
 def generate_gbnf_grammar_from_pydantic_models(
-    models: List[Type[BaseModel]], outer_object_name: str = None, outer_object_content: str = None,
+    models: list[type[BaseModel]], outer_object_name: str | None = None, outer_object_content: str | None = None,
     list_of_outputs: bool = False
 ) -> str:
     """
@@ -528,7 +526,7 @@ def generate_gbnf_grammar_from_pydantic_models(
     * grammar.
 
     Args:
-        models (List[Type[BaseModel]]): A list of Pydantic models to generate the grammar from.
+        models (list[type[BaseModel]]): A list of Pydantic models to generate the grammar from.
         outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
         outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
         list_of_outputs (str, optional): Allows a list of output objects
@@ -543,9 +541,9 @@ def generate_gbnf_grammar_from_pydantic_models(
         # root ::= UserModel | PostModel
         # ...
     """
-    processed_models = set()
+    processed_models: set[type[BaseModel]] = set()
     all_rules = []
-    created_rules = {}
+    created_rules: dict[str, list[str]] = {}
     if outer_object_name is None:
         for model in models:
             model_rules, _ = generate_gbnf_grammar(model, processed_models, created_rules)
@@ -608,7 +606,7 @@ def get_primitive_grammar(grammar):
     Returns:
         str: GBNF primitive grammar string.
     """
-    type_list = []
+    type_list: list[type[object]] = []
     if "string-list" in grammar:
         type_list.append(str)
     if "boolean-list" in grammar:
@@ -666,14 +664,14 @@ triple-quotes ::= "'''" """
 
 
 def generate_markdown_documentation(
-    pydantic_models: List[Type[BaseModel]], model_prefix="Model", fields_prefix="Fields",
+    pydantic_models: list[type[BaseModel]], model_prefix="Model", fields_prefix="Fields",
     documentation_with_field_description=True
 ) -> str:
     """
     Generate markdown documentation for a list of Pydantic models.
 
     Args:
-        pydantic_models (List[Type[BaseModel]]): List of Pydantic model classes.
+        pydantic_models (list[type[BaseModel]]): list of Pydantic model classes.
         model_prefix (str): Prefix for the model section.
         fields_prefix (str): Prefix for the fields section.
         documentation_with_field_description (bool): Include field descriptions in the documentation.
@@ -731,7 +729,7 @@ def generate_markdown_documentation(
 
 
 def generate_field_markdown(
-    field_name: str, field_type: Type[Any], model: Type[BaseModel], depth=1,
+    field_name: str, field_type: type[Any], model: type[BaseModel], depth=1,
     documentation_with_field_description=True
 ) -> str:
     """
@@ -739,8 +737,8 @@ def generate_field_markdown(
 
     Args:
         field_name (str): Name of the field.
-        field_type (Type[Any]): Type of the field.
-        model (Type[BaseModel]): Pydantic model class.
+        field_type (type[Any]): Type of the field.
+        model (type[BaseModel]): Pydantic model class.
         depth (int): Indentation depth in the documentation.
         documentation_with_field_description (bool): Include field descriptions in the documentation.
 
@@ -798,7 +796,7 @@ def generate_field_markdown(
     return field_text
 
 
-def format_json_example(example: dict, depth: int) -> str:
+def format_json_example(example: dict[str, Any], depth: int) -> str:
     """
     Format a JSON example into a readable string with indentation.
 
@@ -819,14 +817,14 @@ def format_json_example(example: dict, depth: int) -> str:
 
 
 def generate_text_documentation(
-    pydantic_models: List[Type[BaseModel]], model_prefix="Model", fields_prefix="Fields",
+    pydantic_models: list[type[BaseModel]], model_prefix="Model", fields_prefix="Fields",
     documentation_with_field_description=True
 ) -> str:
     """
     Generate text documentation for a list of Pydantic models.
 
     Args:
-        pydantic_models (List[Type[BaseModel]]): List of Pydantic model classes.
+        pydantic_models (list[type[BaseModel]]): List of Pydantic model classes.
         model_prefix (str): Prefix for the model section.
         fields_prefix (str): Prefix for the fields section.
         documentation_with_field_description (bool): Include field descriptions in the documentation.
@@ -885,7 +883,7 @@ def generate_text_documentation(
 
 
 def generate_field_text(
-    field_name: str, field_type: Type[Any], model: Type[BaseModel], depth=1,
+    field_name: str, field_type: type[Any], model: type[BaseModel], depth=1,
     documentation_with_field_description=True
 ) -> str:
     """
@@ -893,8 +891,8 @@ def generate_field_text(
 
     Args:
         field_name (str): Name of the field.
-        field_type (Type[Any]): Type of the field.
-        model (Type[BaseModel]): Pydantic model class.
+        field_type (type[Any]): Type of the field.
+        model (type[BaseModel]): Pydantic model class.
         depth (int): Indentation depth in the documentation.
         documentation_with_field_description (bool): Include field descriptions in the documentation.
 
@@ -1017,8 +1015,8 @@ def generate_and_save_gbnf_grammar_and_documentation(
     pydantic_model_list,
     grammar_file_path="./generated_grammar.gbnf",
     documentation_file_path="./generated_grammar_documentation.md",
-    outer_object_name: str = None,
-    outer_object_content: str = None,
+    outer_object_name: str | None = None,
+    outer_object_content: str | None = None,
     model_prefix: str = "Output Model",
     fields_prefix: str = "Output Fields",
     list_of_outputs: bool = False,
@@ -1053,8 +1051,8 @@ def generate_and_save_gbnf_grammar_and_documentation(
 
 def generate_gbnf_grammar_and_documentation(
     pydantic_model_list,
-    outer_object_name: str = None,
-    outer_object_content: str = None,
+    outer_object_name: str | None = None,
+    outer_object_content: str | None = None,
     model_prefix: str = "Output Model",
     fields_prefix: str = "Output Fields",
     list_of_outputs: bool = False,
@@ -1086,9 +1084,9 @@ def generate_gbnf_grammar_and_documentation(
 
 
 def generate_gbnf_grammar_and_documentation_from_dictionaries(
-    dictionaries: List[dict],
-    outer_object_name: str = None,
-    outer_object_content: str = None,
+    dictionaries: list[dict[str, Any]],
+    outer_object_name: str | None = None,
+    outer_object_content: str | None = None,
     model_prefix: str = "Output Model",
     fields_prefix: str = "Output Fields",
     list_of_outputs: bool = False,
@@ -1098,7 +1096,7 @@ def generate_gbnf_grammar_and_documentation_from_dictionaries(
     Generate GBNF grammar and documentation from a list of dictionaries.
 
     Args:
-        dictionaries (List[dict]): List of dictionaries representing Pydantic models.
+        dictionaries (list[dict]): List of dictionaries representing Pydantic models.
         outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
         outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
         model_prefix (str): Prefix for the model section in the documentation.
@@ -1120,7 +1118,7 @@ def generate_gbnf_grammar_and_documentation_from_dictionaries(
     return grammar, documentation
 
 
-def create_dynamic_model_from_function(func: Callable):
+def create_dynamic_model_from_function(func: Callable[..., Any]):
     """
     Creates a dynamic Pydantic model from a given function's type hints and adds the function as a 'run' method.
 
@@ -1135,6 +1133,7 @@ def create_dynamic_model_from_function(func: Callable):
     sig = inspect.signature(func)
 
     # Parse the docstring
+    assert func.__doc__ is not None
     docstring = parse(func.__doc__)
 
     dynamic_fields = {}
@@ -1157,7 +1156,6 @@ def create_dynamic_model_from_function(func: Callable):
                 f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring")
 
         # Add parameter details to the schema
-        param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
         param_docs.append((param.name, param_doc))
         if param.default == inspect.Parameter.empty:
             default_value = ...
@@ -1166,10 +1164,10 @@ def create_dynamic_model_from_function(func: Callable):
         dynamic_fields[param.name] = (
             param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
     # Creating the dynamic model
-    dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)
+    dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)  # type: ignore[call-overload]
 
-    for param_doc in param_docs:
-        dynamic_model.model_fields[param_doc[0]].description = param_doc[1].description
+    for name, param_doc in param_docs:
+        dynamic_model.model_fields[name].description = param_doc.description
 
     dynamic_model.__doc__ = docstring.short_description
 
@@ -1182,16 +1180,16 @@ def create_dynamic_model_from_function(func: Callable):
     return dynamic_model
 
 
-def add_run_method_to_dynamic_model(model: Type[BaseModel], func: Callable):
+def add_run_method_to_dynamic_model(model: type[BaseModel], func: Callable[..., Any]):
     """
     Add a 'run' method to a dynamic Pydantic model, using the provided function.
 
     Args:
-        model (Type[BaseModel]): Dynamic Pydantic model class.
+        model (type[BaseModel]): Dynamic Pydantic model class.
         func (Callable): Function to be added as a 'run' method to the model.
 
     Returns:
-        Type[BaseModel]: Pydantic model class with the added 'run' method.
+        type[BaseModel]: Pydantic model class with the added 'run' method.
     """
 
     def run_method_wrapper(self):
@@ -1204,15 +1202,15 @@ def add_run_method_to_dynamic_model(model: Type[BaseModel], func: Callable):
     return model
 
 
-def create_dynamic_models_from_dictionaries(dictionaries: List[dict]):
+def create_dynamic_models_from_dictionaries(dictionaries: list[dict[str, Any]]):
     """
     Create a list of dynamic Pydantic model classes from a list of dictionaries.
 
     Args:
-        dictionaries (List[dict]): List of dictionaries representing model structures.
+        dictionaries (list[dict]): List of dictionaries representing model structures.
 
     Returns:
-        List[Type[BaseModel]]: List of generated dynamic Pydantic model classes.
+        list[type[BaseModel]]: List of generated dynamic Pydantic model classes.
     """
     dynamic_models = []
     for func in dictionaries:
@@ -1249,7 +1247,7 @@ def list_to_enum(enum_name, values):
     return Enum(enum_name, {value: value for value in values})
 
 
-def convert_dictionary_to_pydantic_model(dictionary: dict, model_name: str = "CustomModel") -> Type[BaseModel]:
+def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name: str = "CustomModel") -> type[Any]:
     """
     Convert a dictionary to a Pydantic model class.
 
@@ -1258,9 +1256,9 @@ def convert_dictionary_to_pydantic_model(dictionary: dict, model_name: str = "Cu
         model_name (str): Name of the generated Pydantic model.
 
     Returns:
-        Type[BaseModel]: Generated Pydantic model class.
+        type[BaseModel]: Generated Pydantic model class.
     """
-    fields = {}
+    fields: dict[str, Any] = {}
 
     if "properties" in dictionary:
         for field_name, field_data in dictionary.get("properties", {}).items():
@@ -1277,7 +1275,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict, model_name: str = "Cu
                     if items != {}:
                         array = {"properties": items}
                         array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
-                        fields[field_name] = (List[array_type], ...)
+                        fields[field_name] = (List[array_type], ...)  # type: ignore[valid-type]
                     else:
                         fields[field_name] = (list, ...)
                 elif field_type == "object":