# mypy: ignore-errors
import torch
import torch.utils._pytree as pytree
from torch.testing._utils import wrapper_set_seed
from functorch.compile import compiled_function, min_cut_rematerialization_partition, nop
from .make_fx import randomize
import re


class assert_raises_regex:
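    """Context manager that expects the wrapped block to raise ``exception_cls``
    with a message matching ``regex``; any other outcome (no exception, a
    different exception type, or a non-matching message) raises AssertionError.
    """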
    def __init__(self, exception_cls, regex):
        self.exception_cls = exception_cls
        self.regex = regex

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, traceback):
        if exc_type == self.exception_cls:
            msg = str(exc_val)
            if not re.search(self.regex, msg):
                raise AssertionError(
                    f"Expected exception to match regex. regex: {self.regex}, exception: {msg}")
            return True  # Squashes the exception
        if exc_type is not None:
            raise AssertionError(
                f"Expected {self.exception_cls} to be raised, instead got exception {exc_type}")
        raise AssertionError("Expected exception to be raised but none was")


def aot_autograd_check(
        func,
        args,
        kwargs,
        dynamic,
        assert_raises_regex_fn=assert_raises_regex,
        assert_equals_fn=torch.testing._comparison.assert_close,
        check_gradients=True,
        try_check_data_specialization=False,
        skip_correctness_check=False):
"""Compares func(*args, **kwargs) in eager-mode to under AOTAutograd.
Compares outputs and (if check_gradients=True) gradients produced by
AOTAutograd against eager-mode PyTorch.
We assume that func(*args, **kwargs) succeeds in eager-mode PyTorch.
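
    Example (an illustrative sketch; ``torch.mul`` and the random inputs below
    are placeholders, not tied to any particular test)::

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3, requires_grad=True)
        aot_autograd_check(torch.mul, (x, y), {}, dynamic=False)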
"""
    flat_args, args_spec = pytree.tree_flatten((args, kwargs))
    args = [arg for arg in flat_args if isinstance(arg, torch.Tensor)]

    # We construct a new function that only accepts Tensors as inputs
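    # (AOTAutograd is handed only the Tensor leaves; the non-Tensor values from
    # the original call are closed over below and effectively treated as constants.)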
    def func_no_tensors(args):
        reconstructed_flat_args = []
        args = iter(args)
        for v in flat_args:
            if isinstance(v, torch.Tensor):
                reconstructed_flat_args.append(next(args))
            else:
                reconstructed_flat_args.append(v)

        c_args, c_kwargs = pytree.tree_unflatten(reconstructed_flat_args, args_spec)
        return func(*c_args, **c_kwargs)

    compiled_f = compiled_function(
        func_no_tensors, nop, nop, dynamic=dynamic, partition_fn=min_cut_rematerialization_partition)
    out = wrapper_set_seed(func_no_tensors, args)
    if check_gradients == "auto":
        any_tensor_requires_grad = pytree.tree_any_only(torch.Tensor, lambda x: x.requires_grad, args)
        any_output_requires_grad = pytree.tree_any_only(torch.Tensor, lambda x: x.requires_grad, out)
        check_gradients = any_tensor_requires_grad and any_output_requires_grad

    if not check_gradients:
        compiled_out = wrapper_set_seed(compiled_f, args)
        if not skip_correctness_check:
            assert_equals_fn(compiled_out, out, msg=outputs_msg)
        return

    _test_aot_autograd_forwards_backwards_helper(
        func_no_tensors, compiled_f, args, assert_raises_regex_fn, assert_equals_fn,
        try_check_data_specialization, skip_correctness_check)


outputs_msg = (
    "Outputs of the operator are different in eager-mode PyTorch vs "
    "AOTAutograd. This means the operator will have incorrect output "
    "underneath torch.compile. This could be because the operator's "
    "implementation is not traceable or that there is a bug in AOTAutograd."
)


def _test_aot_autograd_forwards_backwards_helper(
        f, compiled_f, args, assert_raises_regex_fn, assert_equals_fn,
        try_check_data_specialization, skip_correctness_check=False):
    # Verify grads are equal between compiled and non-compiled versions of f.

    def call_forwards_backwards(f, args):
        flat_args = pytree.arg_tree_leaves(*args)
        diff_args = [arg for arg in flat_args if isinstance(arg, torch.Tensor) and
                     arg.requires_grad]
        out = wrapper_set_seed(f, args)
        flat_out = pytree.tree_leaves(out)

        sm = 0
        for i in flat_out:
            if isinstance(i, torch.Tensor):
                # We need to call .abs() because it is possible that the output of the
                # operator is a complex Tensor and autograd will yell at autograd.grad
                # on a complex Tensor unless we manually provide the grad_outputs argument.
                sm += i.sum().abs()
        assert isinstance(sm, torch.Tensor)
        return out, torch.autograd.grad(sm, diff_args, allow_unused=True)

    def check(args, ignore_failure=False):
        try:
            orig_out, orig_grad = call_forwards_backwards(f, args)
        except Exception:
            if ignore_failure:
                return
            raise

        # See https://github.com/pytorch/pytorch/pull/98960#issuecomment-1505962215
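        # (If eager-mode produced no gradients at all and some Tensor inputs are
        # non-leaf tensors, the compiled path is expected to error out rather than
        # return None grads, so we assert on that error and skip the comparison.)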
        tensor_args = [x for x in pytree.tree_flatten(args)[0] if isinstance(x, torch.Tensor)]
        any_non_leaves = any(x.grad_fn is not None for x in tensor_args)
        if all(x is None for x in orig_grad) and any_non_leaves:
            with assert_raises_regex_fn(RuntimeError, 'does not require grad and does not have a grad_fn'):
                call_forwards_backwards(compiled_f, args)
            return

        msg = (
            "Gradients of the operator are different in eager-mode PyTorch vs "
            "AOTAutograd. This means the operator will have incorrect gradients "
            "underneath torch.compile. This could be because the operator's "
            "backward is incorrectly registered or not traceable or that there "
            "is a bug in AOTAutograd."
        )

        compiled_out, compiled_grad = call_forwards_backwards(compiled_f, args)
        if not skip_correctness_check:
            assert_equals_fn(compiled_out, orig_out, msg=outputs_msg)
            assert_equals_fn(compiled_grad, orig_grad, msg=msg)

    check(args, ignore_failure=False)

    # Randomize the data and run the traced graph with it, to catch bugs
    # where we may have baked in Tensor data into the trace.
    # This is not guaranteed to succeed, because `f` might have preconditions
    # on the values of the inputs, so we just ignore if this test fails.
    if try_check_data_specialization:
        args = randomize(args)
        check(args, ignore_failure=True)