1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
|
from __future__ import annotations
import io
import os
import inspect
import traceback
import contextlib
from typing import Any, TypeVar, Iterator, ForwardRef, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, get_origin, assert_type
import rich
from openai._types import Omit, NoneType
from openai._utils import (
is_dict,
is_list,
is_list_type,
is_union_type,
extract_type_arg,
is_annotated_type,
is_type_alias_type,
)
from openai._compat import PYDANTIC_V2, field_outer_type, get_model_fields
from openai._models import BaseModel
# Type variable bound to `BaseModel`; used by `assert_matches_model` so the model class
# and the instance being checked are tied to the same model type.
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
def evaluate_forwardref(forwardref: ForwardRef, globalns: dict[str, Any]) -> type:
    """Resolve a string forward reference by evaluating it in ``globalns``.

    Args:
        forwardref: The forward reference to resolve, e.g. ``ForwardRef("Foo")``.
        globalns: Global namespace used while evaluating the referenced expression.

    Returns:
        The object the referenced expression evaluates to.
    """
    # `str(ForwardRef("Foo"))` is the repr "ForwardRef('Foo')". Evaluating that would
    # just rebuild the ForwardRef, or raise NameError if `ForwardRef` is not in
    # `globalns`. `__forward_arg__` holds the raw expression text that actually needs
    # evaluating.
    # NOTE: `eval` runs arbitrary code; only use this on trusted annotations.
    return eval(forwardref.__forward_arg__, globalns)  # type: ignore[no-any-return]
def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool:
    """Check every field of ``value`` against the annotation declared on ``model``.

    Each field is delegated to ``assert_matches_type`` with ``path`` extended by the
    field name. Always returns ``True`` so callers can wrap the call in ``assert``.
    """
    for field_name, model_field in get_model_fields(model).items():
        actual = getattr(value, field_name)

        # Only Pydantic v1 carries an `allow_none` flag on the field object; under v2 no
        # extra None allowance is granted here.
        # https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields
        nullable = False if PYDANTIC_V2 else getattr(model_field, "allow_none", False)

        assert_matches_type(
            field_outer_type(model_field),
            actual,
            path=path + [field_name],
            allow_none=nullable,
        )

    return True
# Note: the `path` argument is only used to improve error messages when `--showlocals` is used
def assert_matches_type(
    type_: Any,
    value: object,
    *,
    path: list[str],
    allow_none: bool = False,
) -> None:
    """Recursively assert that ``value`` conforms to the type annotation ``type_``.

    Args:
        type_: The expected type. Type aliases and ``Annotated[T, ...]`` wrappers are
            unwrapped first. Supported forms:
              - ``None``/``NoneType``, ``str``, ``int``, ``bool``, ``float``, ``bytes``;
              - ``datetime``, ``date``, ``Literal``, ``dict``, lists;
              - unions (with a special case for ``Optional[T]``);
              - ``BaseModel`` subclasses and ``HttpxBinaryResponseContent``.
            ``object`` accepts any value.
        value: The runtime value being checked.
        path: Location of ``value`` inside the enclosing structure. Only used to make
            the failing frame's locals informative.
        allow_none: When true, ``None`` is accepted regardless of ``type_``.

    Raises:
        AssertionError: If ``value`` does not match ``type_``, or ``type_`` is a form this
            helper does not handle.
    """
    if is_type_alias_type(type_):
        type_ = type_.__value__

    # unwrap `Annotated[T, ...]` -> `T`
    if is_annotated_type(type_):
        type_ = extract_type_arg(type_, 0)

    if allow_none and value is None:
        return

    if type_ is None or type_ is NoneType:
        assert value is None
        return

    origin = get_origin(type_) or type_

    if is_list_type(type_):
        return _assert_list_type(type_, value)

    if origin == str:
        assert isinstance(value, str)
    elif origin == int:
        assert isinstance(value, int)
    elif origin == bool:
        assert isinstance(value, bool)
    elif origin == float:
        assert isinstance(value, float)
    elif origin == bytes:
        assert isinstance(value, bytes)
    elif origin == datetime:
        assert isinstance(value, datetime)
    elif origin == date:
        assert isinstance(value, date)
    elif origin == object:
        # nothing to do here, the expected type is unknown
        pass
    elif origin == Literal:
        assert value in get_args(type_)
    elif origin == dict:
        assert is_dict(value)

        args = get_args(type_)
        # A bare `dict` / `Dict` annotation has no type arguments. In that case only the
        # container type can be checked; indexing `args` would raise IndexError.
        if args:
            key_type = args[0]
            items_type = args[1]
            for key, item in value.items():
                assert_matches_type(key_type, key, path=[*path, "<dict key>"])
                assert_matches_type(items_type, item, path=[*path, "<dict item>"])
    elif is_union_type(type_):
        variants = get_args(type_)

        try:
            none_index = variants.index(type(None))
        except ValueError:
            pass
        else:
            # special case Optional[T] for better error messages
            if len(variants) == 2:
                if value is None:
                    # valid
                    return

                # With exactly two variants, the non-None one is at the other index.
                return assert_matches_type(type_=variants[1 - none_index], value=value, path=path)

        for i, variant in enumerate(variants):
            try:
                assert_matches_type(variant, value, path=[*path, f"variant {i}"])
                return
            except AssertionError:
                # Print each variant's failure so a total mismatch is debuggable.
                traceback.print_exc()
                continue

        raise AssertionError("Did not match any variants")
    # `issubclass` raises TypeError for non-class origins (e.g. a TypeVar or an
    # unresolved ForwardRef). Guarding it lets such types reach the explicit
    # "Unhandled field type" failure below.
    elif inspect.isclass(origin) and issubclass(origin, BaseModel):
        assert isinstance(value, type_)
        assert assert_matches_model(type_, cast(Any, value), path=path)
    elif inspect.isclass(origin) and origin.__name__ == "HttpxBinaryResponseContent":
        assert value.__class__.__name__ == "HttpxBinaryResponseContent"
    else:
        # Raised explicitly (rather than `assert None, ...`) to match the union branch
        # and so unsupported types always fail loudly.
        raise AssertionError(f"Unhandled field type: {type_}")
def _assert_list_type(type_: type[object], value: object, *, path: list[str] | None = None) -> None:
    """Assert that ``value`` is a list whose every entry matches the list's item type.

    Args:
        type_: A list annotation such as ``List[int]`` or ``list[Model]``.
        value: The runtime value being checked.
        path: Optional location of ``value`` inside the enclosing structure. It is
            extended per entry for more informative failing-frame locals.

    Raises:
        AssertionError: If ``value`` is not a list or an entry does not match.
    """
    assert is_list(value)

    args = get_args(type_)
    # A bare `list` / `List` annotation has no item type, so only the container type
    # can be checked.
    if not args:
        return

    inner_type = args[0]
    base_path = [] if path is None else path
    for index, entry in enumerate(value):
        # `typing_extensions.assert_type` only informs static type checkers and does
        # nothing at runtime, so each entry is validated recursively instead.
        assert_matches_type(inner_type, entry, path=[*base_path, f"<list item {index}>"])
def rich_print_str(obj: object) -> str:
    """Render ``obj`` as ``rich.print()`` would, returning the text instead of printing it.

    Output is produced by a ``rich`` console that is 120 columns wide and writes into
    an in-memory text buffer.
    """
    with io.StringIO() as sink:
        rich.console.Console(file=sink, width=120).print(obj)
        return sink.getvalue()
@contextlib.contextmanager
def update_env(**new_env: str | Omit) -> Iterator[None]:
    """Temporarily apply ``new_env`` to ``os.environ`` inside a ``with`` block.

    A value that is an ``Omit`` instance removes the variable, if present. Any other
    value sets the variable to it. On exit, whether normal or via an exception, the
    environment is reset to the snapshot taken on entry.
    """
    snapshot = dict(os.environ)
    try:
        for var, setting in new_env.items():
            if not isinstance(setting, Omit):
                os.environ[var] = setting
                continue
            os.environ.pop(var, None)
        yield
    finally:
        # Wipe and reload so that variables added or removed inside the block are undone.
        os.environ.clear()
        os.environ.update(snapshot)
|