File: operator.py

package info (click to toggle)
pytorch 2.6.0%2Bdfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 161,672 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (109 lines) | stat: -rw-r--r-- 2,940 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Python polyfills for operator
"""

from __future__ import annotations

import operator
from typing import Any, Callable, overload, TypeVar
from typing_extensions import TypeVarTuple, Unpack

from ..decorators import substitute_in_graph


# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`)
# The public names exported from this module are the three callable factories
# listed below.
__all__ = ["attrgetter", "itemgetter", "methodcaller"]


# Type variables referenced by the ``@overload`` signatures in this module.
# ``_T``/``_T1``/``_T2``/``_Ts`` annotate the item arguments given to
# ``itemgetter``.
_T = TypeVar("_T")
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_Ts = TypeVarTuple("_Ts")
# ``_U``/``_U1``/``_U2``/``_Us`` annotate what the callables returned by
# ``attrgetter`` and ``itemgetter`` produce (a single value, or the elements
# of a tuple).
_U = TypeVar("_U")
_U1 = TypeVar("_U1")
_U2 = TypeVar("_U2")
_Us = TypeVarTuple("_Us")


@overload
def attrgetter(attr: str, /) -> Callable[[Any], _U]:
    ...


@overload
def attrgetter(
    attr1: str, attr2: str, /, *attrs: str
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]:
    ...


# Reference: https://docs.python.org/3/library/operator.html#operator.attrgetter
@substitute_in_graph(operator.attrgetter, is_embedded_type=True)  # type: ignore[arg-type,misc]
def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]:
    """Pure-Python polyfill of ``operator.attrgetter``.

    Args:
        *attrs: One or more attribute names. A name may contain dots, in
            which case it is resolved one attribute at a time
            (``"a.b"`` -> ``getattr(getattr(obj, "a"), "b")``).

    Returns:
        A callable taking a single object. With exactly one name it returns
        the looked-up attribute; with several names it returns a tuple of the
        looked-up attributes, in the order the names were given.

    Raises:
        TypeError: If no name is given, or if any name is not a ``str``.
    """
    if not attrs:
        raise TypeError("attrgetter expected 1 argument, got 0")

    # Validate every name up front, before any getter is built.
    for candidate in attrs:
        if not isinstance(candidate, str):
            raise TypeError("attribute name must be a string")

    def follow_path(target: Any, dotted: str) -> Any:
        # Walk the dotted path left to right, descending one attribute per step.
        current = target
        for part in dotted.split("."):
            current = getattr(current, part)
        return current

    if len(attrs) == 1:
        (sole_path,) = attrs

        def getter(obj: Any) -> Any:
            return follow_path(obj, sole_path)

    else:

        def getter(obj: Any) -> tuple[Any, ...]:  # type: ignore[misc]
            return tuple(follow_path(obj, path) for path in attrs)

    return getter


@overload
def itemgetter(item: _T, /) -> Callable[[Any], _U]:
    ...


@overload
def itemgetter(
    item1: _T1, item2: _T2, /, *items: Unpack[_Ts]
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]:
    ...


# Reference: https://docs.python.org/3/library/operator.html#operator.itemgetter
@substitute_in_graph(operator.itemgetter, is_embedded_type=True)  # type: ignore[arg-type,misc]
def itemgetter(*items: Any) -> Callable[[Any], Any | tuple[Any, ...]]:
    """Pure-Python polyfill of ``operator.itemgetter``.

    Args:
        *items: One or more keys/indices to subscript the operand with.

    Returns:
        A callable taking a single object. With exactly one item it returns
        ``obj[item]``; with several items it returns a tuple of ``obj[item]``
        for each item, in the order the items were given.

    Raises:
        TypeError: If no item is given.
    """
    if not items:
        raise TypeError("itemgetter expected 1 argument, got 0")

    if len(items) == 1:
        (key,) = items

        def getter(obj: Any) -> Any:
            return obj[key]

    else:

        def getter(obj: Any) -> tuple[Any, ...]:  # type: ignore[misc]
            return tuple(obj[key] for key in items)

    return getter


# Reference: https://docs.python.org/3/library/operator.html#operator.methodcaller
@substitute_in_graph(operator.methodcaller, is_embedded_type=True)  # type: ignore[arg-type]
def methodcaller(name: str, /, *args: Any, **kwargs: Any) -> Callable[[Any], Any]:
    """Pure-Python polyfill of ``operator.methodcaller``.

    Args:
        name: Name of the method to invoke on the operand (positional-only).
        *args: Positional arguments forwarded to the method on every call.
        **kwargs: Keyword arguments forwarded to the method on every call.

    Returns:
        A callable taking a single object that looks up ``name`` on it and
        calls the result with the stored ``args`` and ``kwargs``.

    Raises:
        TypeError: If ``name`` is not a ``str``.
    """
    if isinstance(name, str):

        def caller(obj: Any) -> Any:
            # The attribute is looked up on each call, not once at creation.
            method = getattr(obj, name)
            return method(*args, **kwargs)

        return caller

    raise TypeError("method name must be a string")