File: _dataclass_impls.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (152 lines) | stat: -rw-r--r-- 6,374 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Functions for synthesizing magic methods for JIT-compiled dataclasses
import os
from functools import partial
from torch._jit_internal import is_optional, FAKE_FILENAME_PREFIX
from torch._sources import ParsedDef, SourceContext
from typing import Callable, Dict, List
import ast
import dataclasses
import inspect
import sys

def _get_fake_filename(cls, method_name):
    """Build the placeholder filename used for a synthesized dataclass method.

    The result joins ``FAKE_FILENAME_PREFIX``, the class name and the method
    name with the platform's path separator.
    """
    path_parts = (FAKE_FILENAME_PREFIX, cls.__name__, method_name)
    return os.path.join(*path_parts)


def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
    """Assemble and parse the source of a synthesized method.

    Args:
        cls: The class the method is being synthesized for; only its
            ``__name__`` is used (in the fake filename and in error messages).
        name: Name of the method to define.
        body_lines: Statements forming the method body, one per entry. Each
            entry is indented by two spaces beneath the ``def`` line.
        signature: Parameter list and optional return annotation, e.g.
            ``'(self) -> int'``, placed directly after the method name.

    Returns:
        A ParsedDef wrapping the parsed AST together with the generated source
        and a fake filename derived from the class and method names.

    Raises:
        RuntimeError: If the generated source is not valid Python. The
            original SyntaxError is attached as ``__cause__``.
    """
    body = '\n'.join(f'  {b}' for b in body_lines)
    decl = f'def {name}{signature}:\n{body}'

    # Parse the function declaration
    try:
        py_ast = ast.parse(decl)
    except SyntaxError as e:
        # This should only happen if there's some unforeseeable change
        # in the dataclasses module that makes our synthesized code fail.
        # Chain the SyntaxError explicitly so bug reports show what failed to parse.
        raise RuntimeError(
            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
            "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
        ) from e
    fake_filename = _get_fake_filename(cls, name)
    # Parse the function
    return ParsedDef(
        py_ast,
        ctx=SourceContext(
            source=decl,
            filename=fake_filename,
            file_lineno=0,
            leading_whitespace_len=0
        ),
        source=decl,
        filename=fake_filename,
        file_lineno=0
    )


def synthesize__init__(cls) -> ParsedDef:
    """Synthesize an ``__init__`` for a dataclass.

    The signature is taken from the ``__init__`` that the ``dataclasses``
    module already generated for ``cls``; the body assigns every ``init=True``
    field to ``self`` and then calls ``__post_init__`` if the class has one.

    Raises:
        NotImplementedError: If any field declares a ``default_factory``.
    """
    cls_fields = dataclasses.fields(cls)

    # Supporting default factories in the way that people expect would sort of require us to
    # allow compiling lambda functions, which is not currently supported.
    for f in cls_fields:
        if f.default_factory is not dataclasses.MISSING:
            raise NotImplementedError("Default factory initializers are not supported in TorchScript dataclasses")

    # CPython's generated signature is almost what we need; only InitVar annotations
    # have to be rewritten below.
    sig = inspect.signature(cls.__init__)

    # InitVar only exposes its wrapped type via `.type` on Python 3.8+; see CPython commit
    # https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
    init_vars: List[str] = []
    if sys.version_info >= (3, 8):
        rewritten = []
        for arg_name, arg in sig.parameters.items():
            if isinstance(arg.annotation, dataclasses.InitVar):
                # The TorchScript interpreter can't handle InitVar annotations, so swap in the wrapped type
                init_vars.append(arg_name)
                arg = arg.replace(annotation=arg.annotation.type)   # type: ignore[attr-defined]
            rewritten.append(arg)

        sig = sig.replace(parameters=rewritten)

    # One assignment per init-able field that is not an InitVar
    statements = [
        f'self.{f.name} = {f.name}'
        for f in cls_fields
        if f.init and f.name not in init_vars
    ]

    # Forward the InitVar arguments to the user's __post_init__, when there is one
    if hasattr(cls, '__post_init__'):
        post_init_args = ', '.join(init_vars)
        statements.append('self.__post_init__(' + post_init_args + ')')

    # A function body can't be empty
    if not statements:
        statements = ['pass']

    return compose_fn(cls, '__init__', statements, signature=str(sig))

# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
def synthesize__repr__(cls) -> ParsedDef:
    """Synthesize a placeholder ``__repr__`` for a dataclass.

    The generated method returns a fixed string literal of the form
    ``ClassName(a=self.a, b=self.b)`` listing the ``repr=True`` fields; the
    field values are not interpolated into it.
    """
    shown = [field.name for field in dataclasses.fields(cls) if field.repr]
    rendered = ", ".join(f"{field_name}=self.{field_name}" for field_name in shown)
    statement = f"return '{cls.__name__}({rendered})'"
    return compose_fn(cls, '__repr__', [statement], signature='(self) -> str')

def synthesize__hash__(cls) -> ParsedDef:
    """Synthesize a stub ``__hash__`` whose body raises NotImplementedError."""
    # This is just a placeholder to prevent compilation from failing; this won't even get called at
    # all right now because the TorchScript interpreter doesn't call custom __hash__ implementations
    stub = "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
    return compose_fn(cls, '__hash__', [stub], signature='(self) -> int')

# Implementation for __eq__ and __ne__
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
    """Synthesize ``__eq__`` or ``__ne__`` for a dataclass.

    Args:
        cls: The dataclass whose ``compare=True`` fields are compared.
        name: Name of the method to generate.
        converse: The operator that is the logical opposite of the generated
            method: ``"!="`` when generating ``__eq__`` and ``"=="`` when
            generating ``__ne__``.

    Returns:
        The ParsedDef for the generated method.
    """
    if converse == "==":
        # Generating an inequality test. Two instances differ as soon as ANY
        # field pair fails `==`, so return True at the first such pair and
        # False only after every field matched. (Returning False at the first
        # *equal* field would report `(1, 2) != (1, 3)` as False.)
        return synthesize_comparison(
            cls, name, allow_eq=False, raise_on_none=False,
            inner=[f"if not (val1 {converse} val2): return True"],
            # One side None and the other not: the instances are different
            none_mismatch_result=True,
        )

    # Generating an equality test: the first field pair satisfying the
    # converse operator decides the answer is False; otherwise fall through
    # to the final `return True`.
    return synthesize_comparison(cls, name, allow_eq=True, raise_on_none=False, inner=[
        f"if val1 {converse} val2: return False"
    ])

def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
    """Synthesize an ordering method (``__lt__``, ``__le__``, ``__gt__``, ``__ge__``).

    Fields are compared in declaration order: the first field pair that is
    strictly ordered by ``op`` in either direction decides the result.

    Args:
        cls: The dataclass whose ``compare=True`` fields are compared.
        name: Name of the method to generate.
        op: The strict comparison operator to apply to each field pair.
        allow_eq: Result of the generated method when no field pair is
            strictly ordered.

    Returns:
        The ParsedDef for the generated method.
    """
    return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=[
        f"if val1 {op} val2: return True",
        f"elif val2 {op} val1: return False",
    ])

def synthesize_comparison(
    cls,
    name: str,
    allow_eq: bool,
    raise_on_none: bool,
    inner: List[str],
    none_mismatch_result: bool = False,
) -> ParsedDef:
    """Synthesize a field-by-field comparison method.

    For every ``compare=True`` field the generated body binds ``val1`` and
    ``val2`` to the field on ``self`` and ``other`` and then runs ``inner``.

    Args:
        cls: The dataclass whose fields are compared.
        name: Name of the method to generate.
        allow_eq: Value returned by the generated method when no per-field
            statement returned earlier.
        raise_on_none: For Optional fields, whether a None/non-None mismatch
            raises ``TypeError`` in the generated method instead of returning.
        inner: Statements executed for each field pair; they may refer to
            ``val1`` and ``val2`` and may return.
        none_mismatch_result: Value returned on a None/non-None mismatch when
            ``raise_on_none`` is false.

    Returns:
        The ParsedDef for the generated method.
    """
    if raise_on_none:
        none_mismatch = f"  raise TypeError('Cannot compare {cls.__name__} with None')"
    else:
        none_mismatch = f"  return {none_mismatch_result}"

    body = []
    for field in dataclasses.fields(cls):
        if not field.compare:
            continue

        body.extend([
            f"val1 = self.{field.name}",
            f"val2 = other.{field.name}",
        ])
        if is_optional(field.type):
            # Type refinement for optional fields; we need this to avoid type errors from the interpreter
            body.extend([
                "if val1 is not None and val2 is not None:",
                *['  ' + line for line in inner],
                "elif (val1 is None) != (val2 is None):",
                none_mismatch,
            ])
        else:
            body.extend(inner)

    body.append(f"return {allow_eq}")
    return compose_fn(cls, name, body, signature=f'(self, other: {cls.__name__}) -> bool')

# Maps each magic-method name to a callable that takes the dataclass and
# returns the synthesized method. The equality and ordering entries are
# partials that pre-bind the method name and operator.
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
    "__init__": synthesize__init__,
    "__repr__": synthesize__repr__,
    "__hash__": synthesize__hash__,
    **{
        method: partial(synthesize_equality, name=method, converse=converse)
        for method, converse in (
            ("__eq__", "!="),
            ("__ne__", "=="),
        )
    },
    **{
        method: partial(synthesize_inequality, name=method, op=op, allow_eq=allow_eq)
        for method, op, allow_eq in (
            ("__lt__", "<", False),
            ("__le__", "<", True),
            ("__gt__", ">", False),
            ("__ge__", ">", True),
        )
    },
}