1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
import torch
from torch.utils._pytree import tree_flatten, tree_map
from torch.fx.operator_schemas import normalize_function
from torch.testing._internal.jit_utils import clone_inputs
from torch.utils._python_dispatch import TorchDispatchMode
from itertools import combinations
from collections import namedtuple
from copy import deepcopy
# Record types appended by SchemaCheckMode while it checks ops:
#  - Mutation: an argument (arg_name) of op_name was observed to be mutated
#  - Aliasing: an argument (arg_name) of op_name aliased the output labelled output_number
Mutation = namedtuple("Mutation", "op_name arg_name")
Aliasing = namedtuple("Aliasing", "op_name arg_name output_number")
# Simplified naming for C++ classes
# These are short module-level aliases for private bindings on torch._C, as used by SchemaCheckMode:
#  - SchemaArgument: built from a SchemaArgType and an index to identify one argument or output
#  - SchemaArgType: provides the `input` / `output` selectors passed to SchemaArgument
#  - SchemaInfo: built from an op's schema and queried for aliasing and mutability
SchemaArgument = torch._C._SchemaArgument
SchemaArgType = torch._C._SchemaArgType
SchemaInfo = torch._C._SchemaInfo
# This TorchDispatchMode Subclass is used to verify op schemas
# This TorchDispatchMode Subclass currently:
# - Records the called ops
# - Checks for mutations on all inputs
# - Checks for aliasing on all inputs
class SchemaCheckMode(TorchDispatchMode):
    """Dispatch mode that checks executed ops against their declared schemas.

    For every op that reaches ``__torch_dispatch__`` the mode records the op
    name, runs the op, and then compares what actually happened (argument
    mutation, argument/output aliasing, output/output aliasing) with what the
    op's schema declares. Undeclared behavior raises ``RuntimeError``;
    declared behavior that was observed is recorded on the instance.
    """

    def __init__(self):
        # Information recorded for testing purposes. For example:
        #  - incorrect schemas
        #  - overly conservative schemas
        self.ops = []       # schema name of every dispatched op, in call order
        self.mutated = []   # Mutation records for declared-and-observed mutations
        self.aliasing = []  # Aliasing records for declared-and-observed aliasing

    def reset_cache(self):
        """Empty the recorded ops, mutations and aliasing lists in place."""
        self.ops.clear()
        self.mutated.clear()
        self.aliasing.clear()

    def display_ops(self):
        """Print the recorded op names on one line, separated by commas."""
        print(*self.ops, sep=",")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and verify its mutation/aliasing behavior against its schema.

        Args:
            func: The op being dispatched; its ``_schema`` attribute is read.
            types: Unused by this mode.
            args: Positional arguments for ``func``.
            kwargs: Keyword arguments for ``func``; ``None`` is treated as ``{}``.

        Returns:
            The value returned by ``func(*args, **kwargs)``, unmodified.

        Raises:
            RuntimeError: If an argument aliases an output, an argument is
                mutated, or two outputs alias each other, and the schema does
                not allow it.
        """
        # The signature defaults kwargs to None, but ``**None`` is a TypeError
        # at the call below, so normalise it up front.
        kwargs = {} if kwargs is None else kwargs

        def has_mutated(before, after, md):
            # Only plain (exact-type) tensors are compared; sparse CSR tensors
            # are skipped because they do not have strides or storage.
            are_tensors = type(before) is torch.Tensor and type(after) is torch.Tensor
            if are_tensors and before.layout != torch.sparse_csr and after.layout != torch.sparse_csr:
                return not (
                    before.size() == after.size() and
                    torch.allclose(before, after, equal_nan=True) and
                    md[0] == after.stride() and
                    md[1] == after.storage()._cdata
                )
            return False

        def has_aliased(lhs, rhs):
            try:
                return torch._C._overlaps(lhs, rhs)
            except Exception as exception:
                # Values the overlap check cannot inspect are treated as not aliasing.
                if str(exception).startswith("Cannot inspect value of type "):
                    return False
                raise

        def standardize_name(name):
            # Schema arguments named "self" are looked up under "input" in the
            # normalized kwargs.
            return name if name != "self" else "input"

        def unwrap(e):
            # Tensor subclasses that expose an ``elem`` attribute are replaced
            # by that attribute; everything else is returned unchanged.
            if isinstance(e, torch.Tensor) and type(e) is not torch.Tensor:
                try:
                    return e.elem
                except AttributeError:
                    return e
            return e

        def parse_metadata(e):
            # Snapshot (stride, storage pointer) so that metadata changes made
            # by the op can be detected afterwards.
            if isinstance(e, torch.Tensor):
                if type(e) is not torch.Tensor:
                    try:
                        current = e.elem
                        return (deepcopy(current.stride()), current.storage()._cdata)
                    except AttributeError:
                        return None
                # Sparse CSR tensors do not have strides or storage
                elif e.layout != torch.sparse_csr:
                    return (deepcopy(e.stride()), e.storage()._cdata)
            return None

        schema = func._schema
        self.ops.append(schema.name)

        # Clone and process arguments and outputs
        pre_arguments = normalize_function(
            func,
            args,
            kwargs,
            normalize_to_only_use_kwargs=True
        ).kwargs

        c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
        cloned_arguments = {name: tree_map(unwrap, value) for name, value in c_p_args.items()}
        cloned_metadata = {
            name: tree_map(parse_metadata, tree_flatten(value)[0])
            for name, value in pre_arguments.items()
        }

        out = func(*args, **kwargs)
        arguments = {name: tree_map(unwrap, value) for name, value in pre_arguments.items()}
        tuple_out = out if isinstance(out, tuple) else (out, )
        tuple_out = tree_map(unwrap, tuple_out)

        schema_info = SchemaInfo(schema)
        schema_info.add_argument_values(pre_arguments)

        # Process arguments with outputs
        for i, arg in enumerate(schema.arguments):
            name = standardize_name(arg.name)
            after = arguments.get(name)
            if after is None:
                continue
            before = cloned_arguments.get(name)
            md = cloned_metadata.get(name)
            for j, output in enumerate(tuple_out):
                # aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
                if has_aliased(output, after) and schema.name != 'aten::_unsafe_view':
                    if not schema_info.may_contain_alias(
                            SchemaArgument(SchemaArgType.output, j),
                            SchemaArgument(SchemaArgType.input, i)):
                        raise RuntimeError(f'Argument {name} is not defined to alias output but was aliasing')
                    self.aliasing.append(Aliasing(schema.name, name, f"output_{j}"))
            # Compare each flattened leaf of the pre-call clone with the
            # post-call argument and its pre-call metadata snapshot.
            if any(has_mutated(a, b, c) for a, b, c in zip(tree_flatten(before)[0], tree_flatten(after)[0], md)):
                if not schema_info.is_mutable(SchemaArgument(SchemaArgType.input, i)):
                    raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated")
                self.mutated.append(Mutation(schema.name, name))

        # Aliasing between outputs
        for i, j in combinations(range(len(schema.returns)), 2):
            if has_aliased(tuple_out[i], tuple_out[j]):
                if not schema_info.may_contain_alias(
                        SchemaArgument(SchemaArgType.output, i),
                        SchemaArgument(SchemaArgType.output, j)):
                    raise RuntimeError(f'Outputs {i} and {j} alias unexpectedly')

        return out
|