1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
|
# Owner(s): ["module: fx"]
import os
import tempfile
import torch
from torch.fx import subgraph_rewriter, symbolic_trace
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.testing._internal.common_utils import TestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_fx.py TESTNAME\n\n"
"instead."
)
class TestGraphTransformObserver(TestCase):
def test_graph_transform_observer(self):
class M(torch.nn.Module):
def forward(self, x):
val = torch.neg(x)
return torch.add(val, val)
def pattern(x):
return torch.neg(x)
def replacement(x):
return torch.relu(x)
traced = symbolic_trace(M())
log_url = tempfile.mkdtemp()
with GraphTransformObserver(
traced, "replace_neg_with_relu", log_url=log_url
) as ob:
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
self.assertTrue("relu" in ob.created_nodes)
self.assertTrue("neg" in ob.erased_nodes)
current_pass_count = GraphTransformObserver.get_current_pass_count()
self.assertTrue(
os.path.isfile(
os.path.join(
log_url,
f"pass_{current_pass_count}_replace_neg_with_relu_input_graph.dot",
)
)
)
self.assertTrue(
os.path.isfile(
os.path.join(
log_url,
f"pass_{current_pass_count}_replace_neg_with_relu_output_graph.dot",
)
)
)
|