1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301
|
# Owner(s): ["module: unknown"]
from functools import partial
from textwrap import dedent
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import \
(run_tests, IS_SANDCASTLE, clone_input_helper, first_sample, skipIfSlowGradcheckEnv)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, check_alias_annotation
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
torch.set_default_dtype(torch.float32)
# variant testing is only done with torch.float and torch.cfloat to avoid
# excessive test times and maximize signal to noise ratio
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
allowed_dtypes=(torch.float, torch.cfloat))
# Tests operators for consistency between JIT and eager, also checks
# correctness of JIT specific alias schemas and intended
# autodifferentiation behavior.
# Inherits from JitCommonTestCase instead of TestCase directly to share
# functionality with original test_jit.py method operator tests
@skipIfSlowGradcheckEnv
class TestJit(JitCommonTestCase):
    exact_dtype = True

    # Tests that the forward and backward passes of operations produce the
    # same values for the cross-product of op variants (function, method, inplace)
    # and runtimes (eager, traced, scripted).
    # TODO WARNING: inplace x {traced, scripted} not currently tested
    @_variant_ops(op_db)
    def test_variant_consistency_jit(self, device, dtype, op):
        func = op.get_op()

        # Scripting strips the torch.ops prefix from these operators
        # incorrectly; don't bother testing this case. This is checked before
        # any sample inputs are generated because the test is skipped anyway.
        if isinstance(func, torch._ops.OpOverload):
            self.skipTest("variant consistency doesn't work on torch.ops")

        _requires_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)

        # TODO: find better way to standardize on op registration itself..
        has_fake_function = op.name in ["resize_", 'resize_as_']
        if has_fake_function:
            # Only the Tensor method is exercised for these ops, without
            # autograd; samples are generated once, directly with
            # requires_grad=False.
            variants = {'method': getattr(torch.Tensor, op.name)}
            samples = op.sample_inputs(device, dtype, requires_grad=False)
        else:
            # TODO: inplace tests currently fail, fix and add inplace variant
            variants = {'function': func, 'method': op.get_method()}
            include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
            samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad,
                                       include_conjugated_inputs=include_conjugated_inputs)

        tested = False
        for sample in samples:
            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # Scripting and check_alias_analysis do not work with lambdas.
                # Lambdas are typically used to simulate methods without
                # functional variants, so rely on the other variant for
                # testing for now.
                if is_lambda(variant):
                    continue

                tested = True
                try:
                    self.indiv_variant_test_jit(device, dtype, op, sample, func_type, variant, has_fake_function)
                except Exception as e:
                    variant_error_info = dedent(f"""
                        Error testing {op.name} {func_type} variant
                        with dtype: {dtype}
                        with inputs {sample}:
                    """)
                    raise Exception(variant_error_info) from e

        # assertTrue rather than a bare assert so the check survives `python -O`.
        self.assertTrue(tested, "JIT Test does not execute any logic")

    # Runs the scripted/traced consistency, alias-annotation, shape-analysis
    # and autodiff checks for a single (sample, variant) pair. Called by
    # test_variant_consistency_jit, which adds context to any failure.
    def indiv_variant_test_jit(self, device, dtype, op, sample, func_type, variant, has_fake_function):
        _requires_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
        support_script = op.supports_scripting
        # Create accessor for script function variant
        name = op.name + '_' if func_type == 'inplace' else op.name

        # Run with disable_autodiff_subgraph_inlining(True) to test
        # autodiff support. The context manager forces the graph to contain
        # DifferentiableGraph nodes if they are present.
        with disable_autodiff_subgraph_inlining():
            # Check scripted forward, grad, and grad grad
            if support_script:
                script_fn = create_script_fn(self, name, func_type)

            def out_fn(output):
                # Processes the output for autograd
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            def get_sample():
                # In-place ops (trailing '_') would mutate the shared sample
                # input, so hand them a fresh clone each time.
                return clone_input_helper(sample.input) if op.name[-1] == '_' else sample.input

            if support_script:
                check_against_reference(self,
                                        script_fn,
                                        op.get_op(),
                                        out_fn,
                                        (get_sample(),) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

            # Check traced forward, grad, and grad grad
            # TODO: fix tracing here
            supports_tracing = op.supports_tracing and not has_fake_function
            if op.assert_jit_shape_analysis:
                self.assertTrue(supports_tracing)

            if supports_tracing:
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(self,
                                        traced_fn,
                                        op.get_op(),
                                        out_fn,
                                        (get_sample(),) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

            # Check alias annotation schema for correctness (make
            # sure inputs that aren't supposed to be modified aren't).
            # Note: only runs in float32 because the schema isn't affected by
            # dtype, so running it on all dtypes would be excessive.
            if dtype == torch.float32:
                # TODO: no reason why we can't run this with tracing graph
                if support_script and op.name != "rsub":
                    check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
                                           func_type=func_type, aten_name=op.aten_name)

                # TODO: use script graph as well
                checked_shape_analysis = False
                if supports_tracing:
                    out = variant(get_sample(), *sample.args, **sample.kwargs)

                    # right now, tuple of outputs and tensor output supported
                    # TODO: list of tensor outputs
                    tuple_of_tensors = isinstance(out, tuple) and all(isinstance(elem, torch.Tensor) for elem in out)

                    if isinstance(out, torch.Tensor) or tuple_of_tensors:
                        if tuple_of_tensors:
                            sizes = [elem.size() for elem in out]
                        else:
                            sizes = out.size()
                        self.checkShapeAnalysis(sizes, traced_fn.graph, op.assert_jit_shape_analysis)
                        checked_shape_analysis = True
                if op.assert_jit_shape_analysis:
                    self.assertTrue(checked_shape_analysis)

            # Check autodifferentiation of nodes for traced and scripted
            # graphs; only need to check once per sample.
            if dtype == torch.float32:
                # Sandcastle doesn't fuse nodes
                if IS_SANDCASTLE:
                    # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                    nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                    fusible_nodes = []
                else:
                    nonfusible_nodes = op.autodiff_nonfusible_nodes
                    fusible_nodes = op.autodiff_fusible_nodes

                if supports_tracing:
                    self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
                if support_script:
                    self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)

    # alias testing is only done with torch.float for the same reason
    _alias_ops = partial(ops, dtypes=OpDTypes.supported,
                         allowed_dtypes=(torch.float,))

    # Checks that every alias of an op (function, method and inplace forms)
    # is remapped by the TorchScript compiler to the op's aten name, i.e. the
    # scripted graph mentions op.aten_name and not the alias name.
    @_alias_ops((op for op in op_db if op.aliases))
    def test_jit_alias_remapping(self, device, dtype, op):
        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        sample = first_sample(self, samples)

        # [Scripting Data Preparation]
        # Prepare data for test scripting
        # Below we prepare strings of args/kwargs with and without type annotations.
        # These strings are inserted into function template strings which are then torch scripted.
        # - args string is ["t0"] corresponding to the "input" tensor required by the op
        # - args_kw is the value of args and strings of kwargs used to call the op (without type annotations), for example,
        # ["t0", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
        args = ["t0"]

        def quote_strs(v):
            if isinstance(v, str):
                return f"'{v}'"
            return str(v)

        # String values must be quoted in the generated source for positional
        # args too, otherwise they are emitted as bare identifiers. Other
        # positional args keep f-string formatting (which renders 0-dim
        # tensors as their Python scalar value).
        args_kw = args + \
            [f"'{v}'" if isinstance(v, str) else f"{v}" for v in sample.args] + \
            [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]

        original_name = op.aten_name
        expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype

        for a_op in op.aliases:
            inplace = a_op.inplace_variant
            method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
            # A tuple (not a generator) so the variants can safely be
            # iterated more than once.
            variants = tuple(v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)

            # Test scripting:
            # NOTE: the loop variable must stay named `variant`; the
            # non-method template below calls `variant(...)`, which
            # CompilationUnit resolves from this frame's locals.
            for variant in variants:
                variant_name = variant.__name__
                if variant in method_or_inplace:
                    fn_template = dedent('''
                        def _fn(t0{c}):
                            return t0.{alias_name}({args_kw})
                    ''')
                    # remove the first input tensor
                    script = fn_template.format(
                        c=", " if len(args_kw[1:]) > 1 else "",
                        args_kw=", ".join(args_kw[1:]),
                        alias_name=variant_name,
                    )
                else:
                    fn_template = dedent('''
                        def _fn({args}):
                            return variant({args_kw})
                    ''')
                    script = fn_template.format(
                        args=", ".join(args),
                        args_kw=", ".join(args_kw),
                    )

                # Required to avoid undefined value: tensor error in JIT
                # compilation of the function template
                script = script.replace("tensor(", "torch.tensor(")

                scripted = torch.jit.CompilationUnit(script)._fn

                # An in-place variant whose result dtype cannot be written back
                # into the input's dtype is expected to raise.
                if variant is inplace and not torch.can_cast(expected_dtype, dtype):
                    try:
                        inp = clone_input_helper(sample.input)
                        scripted(inp)
                    except Exception:
                        continue
                    self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

                inp = clone_input_helper(sample.input)
                scripted(inp)
                inp = clone_input_helper(sample.input)
                graph = scripted.graph_for(inp)
                FileCheck().check(original_name).check_not(variant_name).run(graph)

            # TODO: alias remapping is not checked under tracing. A tracing
            # check must pass only tensor example inputs to torch.jit.trace
            # (baking sample.args / sample.kwargs into the traced function);
            # passing them positionally binds them to trace's `optimize` /
            # `check_trace` parameters instead.
# Generate per-device (cpu, cuda, ...) subclasses of TestJit in this module.
instantiate_device_type_tests(TestJit, globals())


# Allow running this test file directly as a script.
if __name__ == '__main__':
    run_tests()
|