1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
|
# Owner(s): ["oncall: export"]
import copy
import unittest
import torch
from functorch.experimental import control_flow
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch.export import export
from torch.fx.passes.infra.pass_base import PassResult
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
class TestPassInfra(TestCase):
    """Tests for running passes over an ``ExportedProgram``.

    Exercises ``ExportedProgram._transform_do_not_use`` with the deprecated
    ``_ExportPassBaseDeprecatedDoNotUse`` pass and with plain callables that
    return a ``PassResult``, plus the graph signature's replace hook.
    """

    def test_export_pass_base(self) -> None:
        """A no-op pass keeps node ops/targets and sets stack traces."""

        class Foo(torch.nn.Module):
            def forward(self, x):
                y = torch.cat([x, x])
                return torch.ops.aten.tensor_split.sections(y, 2)

        f = Foo()

        class NullPass(_ExportPassBaseDeprecatedDoNotUse):
            pass

        ep = export(f, (torch.ones(3, 2),))
        # ``graph.nodes`` is a view over the graph, not a copy; snapshot it
        # before transforming so the "before" side cannot be affected by
        # anything the transform does to the original graph.
        old_nodes = list(ep.graph.nodes)
        ep = ep._transform_do_not_use(NullPass())
        new_nodes = list(ep.graph.nodes)

        for node in new_nodes:
            if node.op != "call_function":
                continue
            self.assertTrue(hasattr(node, "stack_trace"))
            self.assertIsNotNone(node.stack_trace)

        self.assertEqual(len(new_nodes), len(old_nodes))
        for new_node, old_node in zip(new_nodes, old_nodes):
            self.assertEqual(new_node.op, old_node.op)
            self.assertEqual(new_node.target, old_node.target)

    @unittest.skipIf(IS_WINDOWS, "Windows not supported")
    def test_cond(self) -> None:
        """The base pass can retrace a program whose ``cond`` branches call
        ``.item()`` bounded by ``torch._check``."""

        class M(torch.nn.Module):
            def forward(self, pred, x, y):
                def true_fn(x, y):
                    b = x.item()
                    torch._check(b >= 2)
                    torch._check(b <= 5)
                    return x - y

                def false_fn(x, y):
                    c = y.item()
                    torch._check(c >= 2)
                    torch._check(c <= 5)
                    return x + y

                ret = control_flow.cond(pred, true_fn, false_fn, [x, y])
                return ret

        x = torch.tensor([2])
        y = torch.tensor([5])
        mod = M()
        # Only checks that export + transform complete without raising.
        _ = export(mod, (torch.tensor(True), x, y))._transform_do_not_use(
            _ExportPassBaseDeprecatedDoNotUse()
        )

    def test_node_name_stability(self) -> None:
        # Tests that graph nodes stay the same for nodes that are not touched
        # during transformation
        class CustomModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Define a parameter
                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
                # Define two buffers
                self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
                self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (
                    x1 + self.my_parameter
                ) * self.my_buffer1 + x2 * self.my_buffer2
                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0)
                return output

        inps = (torch.rand(1), torch.rand(1))
        m = CustomModule()
        ep_before = export(m, inps)
        # Snapshot names before the transform runs.
        before_names = [node.name for node in ep_before.graph.nodes]
        # No op transformation that doesn't perform any meaningful changes to node
        ep_after = ep_before._transform_do_not_use(_ExportPassBaseDeprecatedDoNotUse())
        after_names = [node.name for node in ep_after.graph.nodes]
        # Compare the full lists: zip() would silently ignore any nodes the
        # transform added or dropped.
        self.assertEqual(before_names, after_names)

    def test_graph_signature_updated_after_transformation(self) -> None:
        # Checks that pass infra correctly updates graph signature
        # after transformations.
        class CustomModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
                self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
                self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (
                    x1 + self.my_parameter
                ) * self.my_buffer1 + x2 * self.my_buffer2
                return output

        my_module = CustomModule()
        # Test the custom module with two input tensors
        input_tensor1 = torch.tensor(5.0)
        input_tensor2 = torch.tensor(6.0)
        ep_before = torch.export.export(my_module, (input_tensor1, input_tensor2))

        def modify_input_output_pass(gm):
            # Rename every call_function node; outputs of this module are
            # call_function results, so user_outputs should pick up the suffix.
            for node in gm.graph.nodes:
                if node.op == "call_function":
                    node.name = node.name + "_modified"
            gm.recompile()
            return PassResult(gm, True)

        ep_after = ep_before._transform_do_not_use(modify_input_output_pass)
        new_signature = ep_after.graph_signature
        for node_name in new_signature.user_outputs:
            self.assertIn("_modified", node_name)

        old_signature = ep_before.graph_signature
        self.assertNotEqual(new_signature.user_outputs, old_signature.user_outputs)

    def test_replace_hook_basic(self) -> None:
        """Node renames made while the signature's replace hook is installed
        on the graph module show up in that signature's user outputs."""

        class CustomModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
                self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
                self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (
                    x1 + self.my_parameter
                ) * self.my_buffer1 + x2 * self.my_buffer2
                return output

        my_module = CustomModule()
        inputs = (torch.tensor(6.0), torch.tensor(7.0))
        ep_before = export(my_module, inputs)

        def replace_pass(gm):
            for node in gm.graph.nodes:
                if node.op == "call_function":
                    node.name = node.name + "_modified"
            gm.recompile()
            return PassResult(gm, True)

        # Work on copies so ep_before's module and signature stay untouched
        # for the comparison below.
        gm = copy.deepcopy(ep_before.graph_module)
        sig = copy.deepcopy(ep_before.graph_signature)
        with gm._set_replace_hook(sig.get_replace_hook()):
            replace_pass(gm)

        for node_name in sig.user_outputs:
            self.assertIn("_modified", node_name)

        old_signature = ep_before.graph_signature
        self.assertNotEqual(sig.user_outputs, old_signature.user_outputs)
if __name__ == "__main__":
    # Run this file's tests through PyTorch's internal test runner
    # (imported from torch.testing._internal.common_utils).
    run_tests()
|