# Functions for synthesizing magic methods for JIT-compiled dataclasses
import ast
import dataclasses
import inspect
import os
import sys
from functools import partial
from typing import Callable, Dict, List

from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional
from torch._sources import ParsedDef, SourceContext

def _get_fake_filename(cls, method_name):
    return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)

def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
    body = '\n'.join(f'  {b}' for b in body_lines)
    decl = f'def {name}{signature}:\n{body}'

    # Parse the function declaration
    try:
        py_ast = ast.parse(decl)
    except SyntaxError as e:
        # This should only happen if there's some unforeseeable change
        # in the dataclasses module that makes our synthesized code fail
        raise RuntimeError(
            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
            "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
        ) from e

    fake_filename = _get_fake_filename(cls, name)
    # Bundle the parsed AST together with its source context
    return ParsedDef(
        py_ast,
        ctx=SourceContext(
            source=decl,
            filename=fake_filename,
            file_lineno=0,
            leading_whitespace_len=0
        ),
        source=decl,
        filename=fake_filename,
        file_lineno=0
    )

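# For illustration only (hypothetical arguments): a call such as
#
#   compose_fn(MyCls, '__eq__', ['return True'], signature='(self, other) -> bool')
#
# builds and parses source text of the form
#
#   def __eq__(self, other) -> bool:
#     return True
#
# attributed to a fake filename, so error messages can point back at the
# synthesized source rather than a real file.
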
def synthesize__init__(cls) -> ParsedDef:
    # Supporting default factories in the way that people expect would effectively require
    # compiling lambda functions, which is not currently supported in TorchScript.
    if any(field.default_factory is not dataclasses.MISSING for field in dataclasses.fields(cls)):
        raise NotImplementedError("Default factory initializers are not supported in TorchScript dataclasses")

    # Simply read off the generated __init__ signature from CPython's implementation. It'll be
    # almost correct except for InitVar annotations, which we need to handle specially.
    signature = inspect.signature(cls.__init__)

    # Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
    # see CPython commit here https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
    init_vars: List[str] = []
    if sys.version_info >= (3, 8):
        params = []
        for name, param in signature.parameters.items():
            ann = param.annotation

            if isinstance(ann, dataclasses.InitVar):
                # The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here
                init_vars.append(name)
                params.append(param.replace(annotation=ann.type))  # type: ignore[attr-defined]
            else:
                params.append(param)

        signature = signature.replace(parameters=params)

    body = [
        # Assign all attributes to self
        f'self.{field.name} = {field.name}'
        for field in dataclasses.fields(cls)
        if field.init and field.name not in init_vars
    ]

    # Call user's impl of __post_init__ if it exists
    if hasattr(cls, '__post_init__'):
        body.append('self.__post_init__(' + ', '.join(init_vars) + ')')

    return compose_fn(cls, '__init__', body or ['pass'], signature=str(signature))

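# For illustration only (`Point` is a hypothetical class, not part of this
# module): given
#
#   @dataclasses.dataclass
#   class Point:
#       x: int
#       y: int = 0
#
# synthesize__init__ produces source roughly equivalent to
#
#   def __init__(self, x: int, y: int = 0) -> None:
#     self.x = x
#     self.y = y
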
# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
def synthesize__repr__(cls) -> ParsedDef:
    return compose_fn(
        cls, '__repr__',
        [f"return '{cls.__name__}(" + ", ".join([
            f"{field.name}=self.{field.name}"
            for field in dataclasses.fields(cls) if field.repr
        ]) + ")'"],
        signature='(self) -> str'
    )

def synthesize__hash__(cls) -> ParsedDef:
    return compose_fn(
        cls, '__hash__',
        [
            # This is just a placeholder to prevent compilation from failing; this won't even get called at
            # all right now because the TorchScript interpreter doesn't call custom __hash__ implementations
            "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
        ],
        signature='(self) -> int'
    )

# Implementation for __eq__ and __ne__
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
    return synthesize_comparison(cls, name, allow_eq=True, raise_on_none=False, inner=[
        f"if val1 {converse} val2: return False"
    ])

def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
    return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=[
        f"if val1 {op} val2: return True",
        f"elif val2 {op} val1: return False",
    ])

def synthesize_comparison(cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]) -> ParsedDef:
    body = []
    for field in dataclasses.fields(cls):
        if not field.compare:
            continue

        body.extend([
            f"val1 = self.{field.name}",
            f"val2 = other.{field.name}",
        ])
        body.extend(
            inner if not is_optional(field.type) else [
                # Type refinement for optional fields; we need this to avoid type errors from the interpreter
                "if val1 is not None and val2 is not None:",
                *['  ' + line for line in inner],
                "elif (val1 is None) != (val2 is None):",
                f"  raise TypeError('Cannot compare {cls.__name__} with None')" if raise_on_none else "  return False"
            ]
        )

    body.append(f"return {allow_eq}")
    return compose_fn(cls, name, body, signature=f'(self, other: {cls.__name__}) -> bool')

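# For illustration only (hypothetical class `Foo` with a single comparable
# field `x: Optional[int]`): synthesize_inequality(Foo, '__lt__', op='<',
# allow_eq=False) yields a body along the lines of
#
#   val1 = self.x
#   val2 = other.x
#   if val1 is not None and val2 is not None:
#     if val1 < val2: return True
#     elif val2 < val1: return False
#   elif (val1 is None) != (val2 is None):
#     raise TypeError('Cannot compare Foo with None')
#   return False
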
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
    "__init__": synthesize__init__,
    "__repr__": synthesize__repr__,
    "__hash__": synthesize__hash__,
    "__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
    "__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
    "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
    "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
    "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
    "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
}
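
# A minimal sketch of intended use (assumes standard torch.jit scripting of
# dataclasses; `Point` here is purely illustrative): when a dataclass is
# scripted, the frontend looks up each magic method in this table and compiles
# the synthesized source instead of the CPython-generated function, e.g.
#
#   import torch
#   from dataclasses import dataclass
#
#   @torch.jit.script
#   @dataclass
#   class Point:
#       x: float
#       y: float
#
#   # Point's __init__, __eq__, and comparisons are now the synthesized versions.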