git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
pydantic : replace uses of __annotations__ with get_type_hints (#8474)
author	compilade <redacted>
Sun, 14 Jul 2024 23:51:21 +0000 (19:51 -0400)
committer	GitHub <redacted>
Sun, 14 Jul 2024 23:51:21 +0000 (19:51 -0400)
* pydantic : replace uses of __annotations__ with get_type_hints

* pydantic : fix Python 3.9 and 3.10 support
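
Background: __annotations__ can hold unevaluated strings (e.g. when "from __future__ import annotations" is in effect or forward references are used) and does not include annotations inherited from base classes, whereas typing.get_type_hints resolves them into actual type objects. A minimal sketch, not part of this patch, showing the difference:

    from __future__ import annotations  # PEP 563: annotations are stored as strings
    from typing import get_type_hints

    class Example:
        name: str
        count: int

    print(Example.__annotations__)  # {'name': 'str', 'count': 'int'}  -- raw strings
    print(get_type_hints(Example))  # {'name': <class 'str'>, 'count': <class 'int'>}  -- real types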

examples/pydantic_models_to_grammar.py
examples/pydantic_models_to_grammar_examples.py
requirements/requirements-pydantic.txt

diff --git a/examples/pydantic_models_to_grammar.py b/examples/pydantic_models_to_grammar.py
index d8145710ce250e9ea2378d651898f293141951bb..93e5dcb6c3855a1f0d33e9b333364828a0af8668 100644 (file)
@@ -6,7 +6,7 @@ import re
 from copy import copy
 from enum import Enum
 from inspect import getdoc, isclass
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
 
 from docstring_parser import parse
 from pydantic import BaseModel, create_model
@@ -53,35 +53,38 @@ class PydanticDataType(Enum):
 
 
 def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
-    if isclass(pydantic_type) and issubclass(pydantic_type, str):
+    origin_type = get_origin(pydantic_type)
+    origin_type = pydantic_type if origin_type is None else origin_type
+
+    if isclass(origin_type) and issubclass(origin_type, str):
         return PydanticDataType.STRING.value
-    elif isclass(pydantic_type) and issubclass(pydantic_type, bool):
+    elif isclass(origin_type) and issubclass(origin_type, bool):
         return PydanticDataType.BOOLEAN.value
-    elif isclass(pydantic_type) and issubclass(pydantic_type, int):
+    elif isclass(origin_type) and issubclass(origin_type, int):
         return PydanticDataType.INTEGER.value
-    elif isclass(pydantic_type) and issubclass(pydantic_type, float):
+    elif isclass(origin_type) and issubclass(origin_type, float):
         return PydanticDataType.FLOAT.value
-    elif isclass(pydantic_type) and issubclass(pydantic_type, Enum):
+    elif isclass(origin_type) and issubclass(origin_type, Enum):
         return PydanticDataType.ENUM.value
 
-    elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
-        return format_model_and_field_name(pydantic_type.__name__)
-    elif get_origin(pydantic_type) is list:
+    elif isclass(origin_type) and issubclass(origin_type, BaseModel):
+        return format_model_and_field_name(origin_type.__name__)
+    elif origin_type is list:
         element_type = get_args(pydantic_type)[0]
         return f"{map_pydantic_type_to_gbnf(element_type)}-list"
-    elif get_origin(pydantic_type) is set:
+    elif origin_type is set:
         element_type = get_args(pydantic_type)[0]
         return f"{map_pydantic_type_to_gbnf(element_type)}-set"
-    elif get_origin(pydantic_type) is Union:
+    elif origin_type is Union:
         union_types = get_args(pydantic_type)
         union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
         return f"union-{'-or-'.join(union_rules)}"
-    elif get_origin(pydantic_type) is Optional:
+    elif origin_type is Optional:
         element_type = get_args(pydantic_type)[0]
         return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
-    elif isclass(pydantic_type):
-        return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
-    elif get_origin(pydantic_type) is dict:
+    elif isclass(origin_type):
+        return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(origin_type.__name__)}"
+    elif origin_type is dict:
         key_type, value_type = get_args(pydantic_type)
         return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
     else:
@@ -118,7 +121,7 @@ def get_members_structure(cls, rule_name):
         # Modify this comprehension
         members = [
             f'  "\\"{name}\\"" ":"  {map_pydantic_type_to_gbnf(param_type)}'
-            for name, param_type in cls.__annotations__.items()
+            for name, param_type in get_type_hints(cls).items()
             if name != "self"
         ]
 
@@ -297,17 +300,20 @@ def generate_gbnf_rule_for_type(
     field_name = format_model_and_field_name(field_name)
     gbnf_type = map_pydantic_type_to_gbnf(field_type)
 
-    if isclass(field_type) and issubclass(field_type, BaseModel):
+    origin_type = get_origin(field_type)
+    origin_type = field_type if origin_type is None else origin_type
+
+    if isclass(origin_type) and issubclass(origin_type, BaseModel):
         nested_model_name = format_model_and_field_name(field_type.__name__)
         nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules)
         rules.extend(nested_model_rules)
         gbnf_type, rules = nested_model_name, rules
-    elif isclass(field_type) and issubclass(field_type, Enum):
+    elif isclass(origin_type) and issubclass(origin_type, Enum):
         enum_values = [f'"\\"{e.value}\\""' for e in field_type]  # Adding escaped quotes
         enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
         rules.append(enum_rule)
         gbnf_type, rules = model_name + "-" + field_name, rules
-    elif get_origin(field_type) == list:  # Array
+    elif origin_type is list:  # Array
         element_type = get_args(field_type)[0]
         element_rule_name, additional_rules = generate_gbnf_rule_for_type(
             model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@@ -317,7 +323,7 @@ def generate_gbnf_rule_for_type(
         rules.append(array_rule)
         gbnf_type, rules = model_name + "-" + field_name, rules
 
-    elif get_origin(field_type) == set or field_type == set:  # Array
+    elif origin_type is set:  # Array
         element_type = get_args(field_type)[0]
         element_rule_name, additional_rules = generate_gbnf_rule_for_type(
             model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@@ -371,7 +377,7 @@ def generate_gbnf_rule_for_type(
             gbnf_type = f"{model_name}-{field_name}-optional"
         else:
             gbnf_type = f"{model_name}-{field_name}-union"
-    elif isclass(field_type) and issubclass(field_type, str):
+    elif isclass(origin_type) and issubclass(origin_type, str):
         if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None:
             triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False)
             markdown_string = field_info.json_schema_extra.get("markdown_code_block", False)
@@ -387,8 +393,8 @@ def generate_gbnf_rule_for_type(
             gbnf_type = PydanticDataType.STRING.value
 
     elif (
-        isclass(field_type)
-        and issubclass(field_type, float)
+        isclass(origin_type)
+        and issubclass(origin_type, float)
         and field_info
         and hasattr(field_info, "json_schema_extra")
         and field_info.json_schema_extra is not None
@@ -413,8 +419,8 @@ def generate_gbnf_rule_for_type(
         )
 
     elif (
-        isclass(field_type)
-        and issubclass(field_type, int)
+        isclass(origin_type)
+        and issubclass(origin_type, int)
         and field_info
         and hasattr(field_info, "json_schema_extra")
         and field_info.json_schema_extra is not None
@@ -462,7 +468,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
     if not issubclass(model, BaseModel):
         # For non-Pydantic classes, generate model_fields from __annotations__ or __init__
         if hasattr(model, "__annotations__") and model.__annotations__:
-            model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()}  # pyright: ignore[reportGeneralTypeIssues]
+            model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
         else:
             init_signature = inspect.signature(model.__init__)
             parameters = init_signature.parameters
@@ -470,7 +476,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
                             name != "self"}
     else:
         # For Pydantic models, use model_fields and check for ellipsis (required fields)
-        model_fields = model.__annotations__
+        model_fields = get_type_hints(model)
 
     model_rule_parts = []
     nested_rules = []
@@ -706,7 +712,7 @@ def generate_markdown_documentation(
         else:
             documentation += f"  Fields:\n"  # noqa: F541
         if isclass(model) and issubclass(model, BaseModel):
-            for name, field_type in model.__annotations__.items():
+            for name, field_type in get_type_hints(model).items():
                 # if name == "markdown_code_block":
                 #    continue
                 if get_origin(field_type) == list:
@@ -754,14 +760,17 @@ def generate_field_markdown(
     field_info = model.model_fields.get(field_name)
     field_description = field_info.description if field_info and field_info.description else ""
 
-    if get_origin(field_type) == list:
+    origin_type = get_origin(field_type)
+    origin_type = field_type if origin_type is None else origin_type
+
+    if origin_type == list:
         element_type = get_args(field_type)[0]
         field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
         if field_description != "":
             field_text += ":\n"
         else:
             field_text += "\n"
-    elif get_origin(field_type) == Union:
+    elif origin_type == Union:
         element_types = get_args(field_type)
         types = []
         for element_type in element_types:
@@ -792,9 +801,9 @@ def generate_field_markdown(
             example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example
             field_text += f"{indent}  Example: {example_text}\n"
 
-    if isclass(field_type) and issubclass(field_type, BaseModel):
+    if isclass(origin_type) and issubclass(origin_type, BaseModel):
         field_text += f"{indent}  Details:\n"
-        for name, type_ in field_type.__annotations__.items():
+        for name, type_ in get_type_hints(field_type).items():
             field_text += generate_field_markdown(name, type_, field_type, depth + 2)
 
     return field_text
@@ -855,7 +864,7 @@ def generate_text_documentation(
 
         if isclass(model) and issubclass(model, BaseModel):
             documentation_fields = ""
-            for name, field_type in model.__annotations__.items():
+            for name, field_type in get_type_hints(model).items():
                 # if name == "markdown_code_block":
                 #    continue
                 if get_origin(field_type) == list:
@@ -948,7 +957,7 @@ def generate_field_text(
 
     if isclass(field_type) and issubclass(field_type, BaseModel):
         field_text += f"{indent}  Details:\n"
-        for name, type_ in field_type.__annotations__.items():
+        for name, type_ in get_type_hints(field_type).items():
             field_text += generate_field_text(name, type_, field_type, depth + 2)
 
     return field_text
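
Throughout the hunks above, the isclass/issubclass checks are moved from the annotated type itself to its origin, normalized as origin_type = get_origin(t), falling back to t itself when that is None. The reason is that get_type_hints returns parameterized generics such as list[str] or Optional[int], for which get_origin yields the underlying container type (and None for plain classes). An illustrative sketch, not part of this patch, of that normalization:

    from typing import List, Optional, get_args, get_origin

    def normalize(tp):
        origin = get_origin(tp)
        return tp if origin is None else origin

    print(normalize(int))            # <class 'int'>   (plain class: origin is None)
    print(normalize(List[int]))      # <class 'list'>
    print(normalize(Optional[str]))  # typing.Union    (Optional[X] is Union[X, None])
    print(get_args(List[int]))       # (<class 'int'>,)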
diff --git a/examples/pydantic_models_to_grammar_examples.py b/examples/pydantic_models_to_grammar_examples.py
index 8e7f46cf99a43b3eec9d8ab4ebdd07bd5c7ee87b..504ed98df64f29999384fee2b30318ac70777496 100644 (file)
@@ -20,6 +20,8 @@ def create_completion(prompt, grammar):
     response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
     data = response.json()
 
+    assert data.get("error") is None, data
+
     print(data["content"])
     return data["content"]
 
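
The examples script now fails fast when the server returns an error payload instead of crashing later on a missing "content" key. A hypothetical illustration (the error shape shown here is an assumption, not taken from the server):

    # Suppose the server responded with an error object instead of a completion:
    data = {"error": {"code": 400, "message": "grammar parse failed"}}
    assert data.get("error") is None, data  # AssertionError surfaces the full payload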
diff --git a/requirements/requirements-pydantic.txt b/requirements/requirements-pydantic.txt
index 2f9455b147d68e11f96113a69f87b15d2215b0eb..bdd423e07ea36b0d559ebbb9e67608e9aadc25e8 100644 (file)
@@ -1,2 +1,3 @@
 docstring_parser~=0.15
 pydantic~=2.6.3
+requests