# mypy: ignore-errors
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._utils import wrapper_set_seed
import torch.utils._pytree as pytree


def make_fx_check(
    func,
    args,
    kwargs,
    tracing_mode,
    assert_close=torch.testing.assert_close,
    randomize_data=False,
):
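    """Check that ``func(*args, **kwargs)`` and ``make_fx(func)(*args, **kwargs)``
    produce the same values.

    A mismatch usually means that an abstract impl (meta/FakeTensor impl) is
    incorrect, that the operator is not completely traceable (e.g. it relies on
    global state), or that there is a bug in make_fx. If ``randomize_data`` is
    True, floating-point inputs are re-randomized before the check to catch
    Tensor data that was baked into the trace; runs that fail on the randomized
    inputs are ignored, since ``func`` may have preconditions on input values.
    """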
    f, *new_args = handle_sizes_for_dynamic_shapes(func, args, kwargs)

    def run(f, *args, **kwargs):
        return wrapper_set_seed(f, *args, **kwargs)

    traced_f = make_fx(f, tracing_mode=tracing_mode)(*new_args)

    msg = (
        "op(*args, **kwargs) and make_fx(op)(*args, **kwargs) produced different "
        "values. This could mean that your abstract impls (meta/FakeTensor impls) "
        "are incorrect, that your operator is not completely traceable (e.g., "
        "it relies on some global state), or that there is a bug in make_fx. "
        "Note that if you passed a python function (and not an operator) to "
        "make_fx_check, it is still possible that the python function will "
        "work with torch.compile because it handles capturing pieces of "
        "your python code to compile."
    )

    # Randomize the data and run the traced graph with it, to catch bugs
    # where we may have baked in Tensor data into the trace.
    # This is not guaranteed to succeed, because `f` might have preconditions
    # on the values of the inputs, so we just ignore the failure if we used
    # random data and it fails.
    if randomize_data:
        new_args = randomize(new_args)
    try:
        expected = run(f, *new_args)
    except Exception:
        if randomize_data:
            return
        raise

    result = run(traced_f, *new_args)
    assert_close(result, expected, msg=msg)
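

# Example usage (illustrative sketch only; `my_op` is a hypothetical function,
# not part of this module):
#
#     def my_op(x):
#         return x * 2
#
#     make_fx_check(my_op, args=(torch.randn(3),), kwargs={}, tracing_mode="real")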


# Arguably we should make make_fx promote torch.Size() objects to symbolic shapes.
# Absent that, here is our strategy:
#
# If any argument is a torch.Size(), maybe get dynamic shapes for it by:
# - Create a temporary Tensor whose size is the torch.Size() we want. Note that
#   we use an expanded Tensor as we cannot pass "meta" Tensors to make_fx.
# - Pass it to make_fx so that it is converted to a proxy Tensor
# - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in
#   symbolic mode, a no-op otherwise)
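#
# For example (illustrative values only): if args = [x, torch.Size([2, 3])],
# then extra_args = [(1, torch.empty(2, 3))], and inside the wrapper `f` the
# original Size argument is recovered as args[1] = extra_args[0][1].size().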
def handle_sizes_for_dynamic_shapes(func, args, kwargs):
    def f(args, kwargs, extra_args, extra_kwargs):
        if extra_args:
            for i, t in extra_args:
                args[i] = t.size()
        if extra_kwargs:
            for k, t in extra_kwargs.items():
                kwargs[k] = t.size()
        return func(*args, **kwargs)

    extra_args = []
    extra_kwargs = {}
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Size):
            extra_args.append((i, torch.empty(arg, device="cpu")))
    for key, value in kwargs.items():
        if isinstance(value, torch.Size):
            extra_kwargs[key] = torch.empty(value, device="cpu")

    return f, args, kwargs, extra_args, extra_kwargs
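

# Replace the data of every floating-point Tensor in `args` with fresh values
# drawn uniformly from [0, 1), preserving requires_grad. Non-floating-point
# Tensors are returned unchanged.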
def randomize(args):
    def transform(x):
        if not x.dtype.is_floating_point:
            return x
        return x.detach().clone().uniform_(0, 1).requires_grad_(x.requires_grad)

    return pytree.tree_map_only(torch.Tensor, transform, args)