File: utils.py

package info (click to toggle)
python-openai 1.99.9-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,784 kB
  • sloc: python: 57,274; sh: 140; makefile: 7
file content (66 lines) | stat: -rw-r--r-- 1,993 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import annotations

import inspect
from typing import Any, Iterable
from typing_extensions import TypeAlias

import pytest
import pydantic

from ..utils import rich_print_str

# Type of the value returned by a `__repr_args__` implementation as annotated
# in this module: an iterable of `(name, value)` pairs whose name may be `None`.
# The alias is a string literal, so the `str | None` / `tuple[...]` expression
# is not evaluated when this module is imported.
ReprArgs: TypeAlias = "Iterable[tuple[str | None, Any]]"


def print_obj(obj: object, monkeypatch: pytest.MonkeyPatch) -> str:
    """Render `obj` through `rich_print_str` and return the text.

    For the duration of the rendering, `pydantic.BaseModel.__repr_args__` is
    swapped for a wrapper that sorts the `(name, value)` pairs it yields, so
    model fields always appear in the same order and the result can be used
    in snapshot tests. The patch is applied through `monkeypatch.context()`
    and is therefore undone as soon as rendering has finished.

    Args:
        obj: The object to render.
        monkeypatch: The pytest monkeypatch fixture used to install the
            temporary `__repr_args__` override.

    Returns:
        The rendered string with every `<caller>.<locals>.` prefix removed,
        where `<caller>` is the name of the function that called `print_obj`.
    """
    unsorted_repr_args = pydantic.BaseModel.__repr_args__

    def _sort_key(pair: Any) -> Any:
        # order by the field name; entries without a (truthy) name
        # fall back to the whole pair as their key
        return pair[0] or pair

    def _sorted_repr_args(self: pydantic.BaseModel) -> ReprArgs:
        return sorted(unsorted_repr_args(self), key=_sort_key)

    with monkeypatch.context() as patcher:
        patcher.setattr(pydantic.BaseModel, "__repr_args__", _sorted_repr_args)
        rendered = rich_print_str(obj)

    # Drop all `fn_name.<locals>.` occurrences so the same snapshots can be
    # shared between pydantic v1 and pydantic v2, whose output for generic
    # models differs, e.g.
    #
    # v2: `ParsedChatCompletion[test_parse_pydantic_model.<locals>.Location]`
    # v1: `ParsedChatCompletion[Location]`
    #
    # `stacklevel=2` resolves to the function that called `print_obj`; this
    # call must stay in this frame for that count to remain correct.
    return clear_locals(rendered, stacklevel=2)


def get_caller_name(*, stacklevel: int = 1) -> str:
    """Return the name of the code object `stacklevel` frames up the stack.

    Counting starts at this function's own frame: `stacklevel=0` yields
    `"get_caller_name"`, the default `stacklevel=1` yields the name of the
    direct caller, `2` the caller's caller, and so on.

    Args:
        stacklevel: How many frames to walk up from this function's frame.

    Returns:
        The `co_name` of the code object belonging to the selected frame.

    Raises:
        AssertionError: If the current frame is unavailable or the stack has
            fewer than `stacklevel` frames above this one.
    """
    current = inspect.currentframe()
    assert current is not None

    depth = 0
    while depth < stacklevel:
        # step to the calling frame; `f_back` is None once the stack runs out
        current = current.f_back
        assert current is not None, f"no {depth}th frame"
        depth += 1

    return current.f_code.co_name


def clear_locals(string: str, *, stacklevel: int) -> str:
    """Remove every `<name>.<locals>.` prefix from `string`.

    `<name>` is the function name found `stacklevel` frames above this
    function, i.e. `stacklevel=1` refers to the direct caller of
    `clear_locals`.

    Args:
        string: The text to clean up.
        stacklevel: How many frames above this function to look for the
            function name whose `<locals>` prefix should be stripped.

    Returns:
        `string` with all occurrences of the prefix removed.
    """
    # one extra level accounts for this function's own frame, which sits
    # between `get_caller_name` and the frames the caller is counting
    owner = get_caller_name(stacklevel=1 + stacklevel)
    prefix = f"{owner}.<locals>."
    return string.replace(prefix, "")


def get_snapshot_value(snapshot: Any) -> Any:
    """Unwrap the value stored behind a snapshot-like object.

    Resolution order:

    1. No `_old_value` attribute: `snapshot` itself is returned.
    2. `_old_value` has no `value` attribute: `_old_value` is returned.
    3. `_old_value.value` has a truthy `_load_value` attribute: the result of
       calling it is returned.
    4. Otherwise `_old_value.value` is returned.

    Args:
        snapshot: Either a plain value or an object exposing `_old_value`.

    Returns:
        The unwrapped value as described above.
    """
    if hasattr(snapshot, "_old_value"):
        previous = snapshot._old_value

        if hasattr(previous, "value"):
            # the stored value may expose a loader that produces the real value
            load = getattr(previous.value, "_load_value", None)
            if load:
                return load()
            return previous.value

        return previous

    return snapshot