1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
|
# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import io
import unittest
import torch
import torch._dynamo as torchdynamo
import torch.utils._pytree as pytree
from torch._dynamo.test_case import TestCase
from torch.export import export, load, save
from torch.export._trace import _export
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
)
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
TestCase as TorchTestCase,
)
from torch.testing._internal.hop_db import (
hop_db,
hop_that_doesnt_have_opinfo_test_allowlist,
)
hop_tests = []
for op_info in hop_db:
op_info_hop_name = op_info.name
if op_info_hop_name in hop_that_doesnt_have_opinfo_test_allowlist:
continue
hop_tests.append(op_info)
class TestHOPGeneric(TestCase):
def test_all_hops_have_op_info(self):
from torch._ops import _higher_order_ops
hops_that_have_op_info = set([k.name for k in hop_db])
all_hops = _higher_order_ops.keys()
missing_ops = []
for op in all_hops:
if (
op not in hops_that_have_op_info
and op not in hop_that_doesnt_have_opinfo_test_allowlist
):
missing_ops.append(op)
self.assertTrue(len(missing_ops) == 0, f"Missing op info for {missing_ops}")
@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestHOP(TestCase):
    """Export tests parametrized over the HOP OpInfo entries in ``hop_tests``.

    Every test wraps the op in a small ``torch.nn.Module``, exports it through
    one pipeline (``export``, pre-dispatch ``_export``, pre-dispatch export
    followed by ``run_decompositions``, or that plus a save/load round trip)
    and compares the exported program's outputs with the eager outputs.
    """

    def _compare(self, eager_model, export, args, kwargs):
        """Assert that ``export.module()`` matches ``eager_model`` on the inputs.

        Args:
            eager_model: Callable invoked as ``eager_model(*args, **kwargs)``
                to produce the reference outputs.
            export: Object whose ``module()`` result is called with the same
                inputs. The parameter name shadows the imported ``export``
                function inside this method; it is kept so existing callers
                keep working.
            args: Positional inputs.
            kwargs: Keyword inputs.

        Raises:
            AssertionError: If the flattened outputs differ in count, type
                or value.
        """
        # All copies are taken before anything runs, so each call receives
        # its own untouched copy of the inputs.
        eager_args = copy.deepcopy(args)
        eager_kwargs = copy.deepcopy(kwargs)
        export_args = copy.deepcopy(args)
        export_kwargs = copy.deepcopy(kwargs)

        flat_orig_outputs = pytree.tree_leaves(eager_model(*eager_args, **eager_kwargs))
        flat_loaded_outputs = pytree.tree_leaves(
            export.module()(*export_args, **export_kwargs)
        )

        # zip() stops at the shorter sequence, so without this check a missing
        # or extra output leaf would go unnoticed.
        self.assertEqual(len(flat_orig_outputs), len(flat_loaded_outputs))
        for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
            self.assertEqual(type(orig), type(loaded))
            self.assertEqual(orig, loaded)

    def _export_cases(self, device, dtype, op):
        """Yield ``(model, args, kwargs)`` for each sample input of ``op``.

        ``model`` is a fresh module whose ``forward(*args)`` calls ``op.op``.
        ``args`` is the sample's ``input`` (wrapped in a tuple unless it
        already is one) followed by the sample's ``args``; ``kwargs`` is the
        sample's ``kwargs``.
        """

        class Foo(torch.nn.Module):
            def forward(self, *args):
                return op.op(*args)

        for sample in op.sample_inputs(device, dtype, requires_grad=True):
            leading = sample.input if isinstance(sample.input, tuple) else (sample.input,)
            yield Foo(), (*leading, *sample.args), sample.kwargs

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_aot_export(self, device, dtype, op):
        """Export with ``torch.export.export`` and compare against eager."""
        for model, args, kwargs in self._export_cases(device, dtype, op):
            ep = export(model, args, kwargs)
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_pre_dispatch_export(self, device, dtype, op):
        """Export with ``_export(..., pre_dispatch=True)`` and compare against eager."""
        for model, args, kwargs in self._export_cases(device, dtype, op):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_retrace_export(self, device, dtype, op):
        """Pre-dispatch export, then ``run_decompositions()``, then compare."""
        for model, args, kwargs in self._export_cases(device, dtype, op):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            ep = ep.run_decompositions()
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_serialize_export(self, device, dtype, op):
        """Pre-dispatch export, decompose, save/load via a buffer, then compare."""
        for model, args, kwargs in self._export_cases(device, dtype, op):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            ep = ep.run_decompositions()

            # Round-trip the exported program through an in-memory buffer.
            buffer = io.BytesIO()
            save(ep, buffer)
            buffer.seek(0)
            ep = load(buffer)

            if "while_loop" in str(op):
                # while_loop's arguments are cast into a list after
                # deserialization, but while_loop expects them to still be a
                # tuple, so running the loaded program is expected to raise.
                with self.assertRaisesRegex(
                    RuntimeError, "carried_inputs must be a tuple"
                ):
                    self._compare(model, ep, args, kwargs)
            else:
                self._compare(model, ep, args, kwargs)
# Hand TestHOP and this module's globals to the device-type test
# instantiation helper.
instantiate_device_type_tests(TestHOP, globals())

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
|