1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
|
from __future__ import annotations
import inspect
from typing import Any, Iterable
from typing_extensions import TypeAlias
import pytest
import pydantic
from ..utils import rich_print_str
# Return type of the `__repr_args__` replacement installed by `print_obj`:
# an iterable of `(name, value)` pairs in which the name may be `None`.
ReprArgs: TypeAlias = "Iterable[tuple[str | None, Any]]"
def print_obj(obj: object, monkeypatch: pytest.MonkeyPatch) -> str:
    """Render `obj` to a string with a stable pydantic field order.

    While the object is rendered, `pydantic.BaseModel.__repr_args__` is
    temporarily replaced (inside a `monkeypatch.context()`) by a version that
    sorts its entries, so the output is deterministic enough to be compared
    against snapshots.
    """
    unsorted_repr_args = pydantic.BaseModel.__repr_args__

    def _sort_key(entry: Any) -> Any:
        # order by the entry's name; an entry with a falsy name falls back
        # to the whole `(name, value)` pair
        return entry[0] or entry

    def _sorted_repr_args(model: pydantic.BaseModel) -> ReprArgs:
        return sorted(unsorted_repr_args(model), key=_sort_key)

    with monkeypatch.context() as patcher:
        patcher.setattr(pydantic.BaseModel, "__repr_args__", _sorted_repr_args)
        rendered = rich_print_str(obj)

    # Drop every `fn_name.<locals>.` prefix so that the same snapshots can be
    # shared between pydantic v1 and pydantic v2, whose output for generic
    # models differs, e.g.
    #
    # v2: `ParsedChatCompletion[test_parse_pydantic_model.<locals>.Location]`
    # v1: `ParsedChatCompletion[Location]`
    #
    # `stacklevel=2` makes `clear_locals` resolve the name of the function
    # that called `print_obj`, so this call must stay directly in this frame.
    return clear_locals(rendered, stacklevel=2)
def get_caller_name(*, stacklevel: int = 1) -> str:
    """Return the name of the code object `stacklevel` frames above this call.

    With the default `stacklevel=1` this is the function that called
    `get_caller_name`; `stacklevel=0` yields `get_caller_name` itself.

    Raises:
        AssertionError: if the current frame is unavailable or the stack has
            fewer than `stacklevel` frames above this one.
    """
    current = inspect.currentframe()
    assert current is not None

    # walk up the call stack one frame at a time
    depth = 0
    while depth < stacklevel:
        current = current.f_back
        assert current is not None, f"no {depth}th frame"
        depth += 1

    return current.f_code.co_name
def clear_locals(string: str, *, stacklevel: int) -> str:
    """Remove every `<caller>.<locals>.` prefix from `string`.

    `<caller>` is the name of the function `stacklevel` frames above this
    one, i.e. `stacklevel=1` refers to the function calling `clear_locals`.
    """
    # one extra level accounts for this function's own frame
    caller_name = get_caller_name(stacklevel=stacklevel + 1)
    locals_prefix = f"{caller_name}.<locals>."
    return string.replace(locals_prefix, "")
def get_snapshot_value(snapshot: Any) -> Any:
    """Unwrap the underlying value of a snapshot-like object.

    Resolution order:
      * no `_old_value` attribute -> `snapshot` itself is returned
      * `_old_value` has no `value` attribute -> `_old_value` is returned
      * `_old_value.value` has a truthy `_load_value` -> the result of
        calling it is returned
      * otherwise -> `_old_value.value` is returned
    """
    if hasattr(snapshot, "_old_value"):
        previous = snapshot._old_value
        if hasattr(previous, "value"):
            load = getattr(previous.value, "_load_value", None)
            if load:
                return load()
            return previous.value
        return previous
    return snapshot
|