1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264
|
# Owner(s): ["module: onnx"]
import pytorch_test_common
import torch
import torch._dynamo
import torch.fx
from torch.onnx._internal.fx.passes import _utils as pass_utils
from torch.testing._internal import common_utils
class TestFxPasses(common_utils.TestCase):
    """Tests for FX graph helpers used by the ONNX exporter and for its FX-node checks."""

    def _assert_node_names_unique(self, nodes):
        # Shared check: no two nodes in ``nodes`` may end up with the same name.
        self.assertEqual(
            len({node.name for node in nodes}),
            len(nodes),
            msg=f"Expected all names to be unique, got {nodes}",
        )

    def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self):
        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
        torch._dynamo.reset()

        # Purposely name the nodes in a way that will cause a recursive collision later:
        # nodes[1] -> "tensor", nodes[2] -> "tensor.1", nodes[3] -> "tensor.2", ...
        # See :func:`set_node_name` for name collision renaming logic.
        base_name = "tensor"
        nodes = list(gm.graph.nodes)
        for i, node in enumerate(nodes[1:]):
            node.name = base_name if i == 0 else f"{base_name}.{i}"

        # Run `set_node_name` and verify that the names are correct.
        name_to_node = {node.name: node for node in nodes}
        pass_utils.set_node_name(nodes[0], base_name, name_to_node)
        # unittest assertions (not bare `assert`) so the checks survive `python -O`.
        self.assertEqual(
            nodes[0].name, base_name, msg=f"Expected {base_name}, got {nodes[0].name}"
        )
        self._assert_node_names_unique(nodes)

    def test_set_node_name_succeeds_when_no_name_collisions(self):
        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
        torch._dynamo.reset()

        # Run `set_node_name` and verify that the names are correct.
        new_name = "some_tensor"
        nodes = list(gm.graph.nodes)
        name_to_node = {node.name: node for node in nodes}
        pass_utils.set_node_name(nodes[1], new_name, name_to_node)
        # Report the name of the node that was actually renamed (nodes[1]).
        self.assertEqual(
            nodes[1].name, new_name, msg=f"Expected {new_name}, got {nodes[1].name}"
        )
        self._assert_node_names_unique(nodes)

    def test_onnx_dynamo_export_raises_when_model_contains_unsupported_fx_nodes(self):
        # Registered as a cleanup so the dynamo reset also runs when an assertion
        # below fails.
        self.addCleanup(torch._dynamo.reset)

        # Two CPU-only custom ops, each with a fake implementation; the export
        # below is expected to reject both as unsupported FX nodes.
        @torch.library.custom_op(
            "mylibrary::foo_op", device_types="cpu", mutates_args=()
        )
        def foo_op(x: torch.Tensor) -> torch.Tensor:
            return x + 1

        @torch.library.custom_op(
            "mylibrary::bar_op", device_types="cpu", mutates_args=()
        )
        def bar_op(x: torch.Tensor) -> torch.Tensor:
            return x + 2

        @foo_op.register_fake
        def _(x):
            return torch.empty_like(x)

        @bar_op.register_fake
        def _(x):
            return torch.empty_like(x)

        def func(x, y, z):
            return foo_op(x) + bar_op(y) + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        with self.assertRaises(torch.onnx.OnnxExporterError) as ctx:
            torch.onnx.dynamo_export(func, x, y, z)

        # The descriptive message lives on the chained cause; fail clearly if absent.
        inner_exception = ctx.exception.__cause__
        self.assertIsNotNone(inner_exception)
        self.assertRegex(
            str(inner_exception),
            r"Unsupported FX nodes.*mylibrary\.foo_op.*mylibrary\.bar_op",
        )
@common_utils.instantiate_parametrized_tests
class TestModularizePass(common_utils.TestCase):
    """Checks the ONNX function names produced by ``torch.onnx.dynamo_export``.

    Each test runs against both a ``torch.export``-ed program and a plain
    ``nn.Module`` input.
    """

    # Shared parametrization: the same two input kinds for every test below.
    _MODEL_KIND_SUBTESTS = [
        common_utils.subtest(True, name="exported_program"),
        common_utils.subtest(False, name="nn_module"),
    ]

    @staticmethod
    def _export_function_names(module_cls, is_exported_program):
        """Export a fresh ``module_cls()`` to ONNX and list its function names.

        If ``is_exported_program`` is true, the module is first run through
        ``torch.export.export`` with two random length-3 tensors. The
        ``nn.Module`` (or exported program) is then passed to
        ``torch.onnx.dynamo_export`` with two new random length-3 tensors.
        """
        model = module_cls()
        if is_exported_program:
            model = torch.export.export(
                model, args=(torch.randn(3), torch.randn(3))
            )
        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
        return [function.name for function in onnx_program.model_proto.functions]

    @pytorch_test_common.xfail(
        error_message="'torch_nn_modules_activation_GELU_used_gelu_1' not found",
        reason="optimizer",
    )
    @common_utils.parametrize("is_exported_program", _MODEL_KIND_SUBTESTS)
    def test_modularize_pass_succeeds_when_submodule_output_is_unused(
        self, is_exported_program
    ):
        # The model is ill-formed, but the exporter must not crash on it.
        # A submodule with zero outputs is illegal. In the modularization pass this
        # can happen when the submodule's output is unused, so none of its inner
        # nodes connect to any outer node. Such a submodule should be erased by
        # DCE, so this situation should never actually occur.
        #
        # Minified repro from Background_Matting. https://github.com/pytorch/benchmark/issues/1768
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.unused_relu = torch.nn.ReLU()
                self.used_gelu = torch.nn.GELU()

            def forward(self, x, y):
                result = self.used_gelu(x + y)
                # Deliberately computed and discarded: this is the unused submodule output.
                unused_relu_result = self.unused_relu(x)
                return result

        names = self._export_function_names(TestModule, is_exported_program)
        self.assertIn("torch_nn_modules_activation_GELU_used_gelu_1", names)
        self.assertFalse(any("ReLU" in name for name in names))

    @pytorch_test_common.xfail(
        error_message="'torch_nn_modules_activation_ReLU_relu_1' not found",
        reason="optimizer",
    )
    @common_utils.parametrize("is_exported_program", _MODEL_KIND_SUBTESTS)
    def test_modularize_pass_succeeds_when_a_submodule_is_called_multiple_times(
        self, is_exported_program
    ):
        # The same ReLU instance is invoked twice; each call should get its own function.
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x, y):
                out = x + y
                out = self.relu(out)
                out = out + x
                out = self.relu(out)
                return out

        names = self._export_function_names(TestModule, is_exported_program)
        self.assertIn("torch_nn_modules_activation_ReLU_relu_1", names)
        self.assertIn("torch_nn_modules_activation_ReLU_relu_2", names)

    @pytorch_test_common.xfail(
        error_message="'torch_nn_modules_activation_ReLU_inner_module_relu_1' not found",
        reason="optimizer",
    )
    @common_utils.parametrize("is_exported_program", _MODEL_KIND_SUBTESTS)
    def test_modularize_pass_succeeds_when_a_submodule_is_called_from_multiple_layers(
        self, is_exported_program
    ):
        # Minified repro from basic_gnn_edgecnn: the same ReLU is reached once through
        # its parent module and once directly as an attribute of that parent.
        class InnerModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner_module = InnerModule()

            def forward(self, x, y):
                out = x + y
                out = self.inner_module(out)
                out = out + x
                out = self.inner_module.relu(out)
                return out

        names = self._export_function_names(TestModule, is_exported_program)
        self.assertIn("torch_nn_modules_activation_ReLU_inner_module_relu_1", names)
        self.assertIn("torch_nn_modules_activation_ReLU_inner_module_relu_2", names)
        # The local module's qualified name is unstable across test invocation
        # methods, so only a substring of the function name is checked.
        self.assertTrue(any("InnerModule_inner_module_1" in name for name in names))
# Run this file's tests through PyTorch's shared test runner when executed as a script.
if __name__ == "__main__":
    common_utils.run_tests()
|