1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
|
# Owner(s): ["module: fx"]
import torch
from torch.fx import symbolic_trace
from torch.testing._internal.common_utils import TestCase
class TestFXNodeHook(TestCase):
def test_hooks_for_node_update(self):
global create_node_hook1_called
global create_node_hook2_called
global erase_node_hook1_called
global erase_node_hook2_called
global replace_node_hook1_called
global replace_node_hook2_called
create_node_hook1_called = False
create_node_hook2_called = False
erase_node_hook1_called = False
erase_node_hook2_called = False
replace_node_hook1_called = False
replace_node_hook2_called = False
def fn(a, b, c):
x = torch.nn.functional.linear(a, b)
x = x + c
return x.cos()
def create_node_hook1(node):
global create_node_hook1_called
create_node_hook1_called = True
def create_node_hook2(node):
global create_node_hook2_called
create_node_hook2_called = True
def erase_node_hook1(node):
global erase_node_hook1_called
erase_node_hook1_called = True
def erase_node_hook2(node):
global erase_node_hook2_called
erase_node_hook2_called = True
def replace_node_hook1(old, new, user):
global replace_node_hook1_called
self.assertEqual(old.name, "a")
self.assertEqual(new, "a_1")
self.assertEqual(user.name, "linear")
replace_node_hook1_called = True
def replace_node_hook2(old, new, user):
global replace_node_hook2_called
replace_node_hook2_called = True
gm = symbolic_trace(fn)
gm._register_create_node_hook(create_node_hook1)
gm._register_create_node_hook(create_node_hook2)
gm._register_erase_node_hook(erase_node_hook1)
gm._register_erase_node_hook(erase_node_hook2)
gm._register_replace_node_hook(replace_node_hook1)
gm._register_replace_node_hook(replace_node_hook2)
graph = gm.graph
node_a = None
for node in graph.find_nodes(op="placeholder"):
node_a = node
break
assert node_a is not None
# This will create a new node
node_a_copy = graph.node_copy(node_a)
node_a.replace_all_uses_with(node_a_copy)
graph.erase_node(node_a)
assert (
create_node_hook1_called
and create_node_hook2_called
and erase_node_hook1_called
and erase_node_hook2_called
and replace_node_hook1_called
and replace_node_hook2_called
)
gm._unregister_create_node_hook(create_node_hook1)
gm._unregister_create_node_hook(create_node_hook2)
gm._unregister_erase_node_hook(erase_node_hook1)
gm._unregister_erase_node_hook(erase_node_hook2)
gm._unregister_replace_node_hook(replace_node_hook1)
gm._unregister_replace_node_hook(replace_node_hook2)
assert gm._create_node_hooks == []
assert gm._erase_node_hooks == []
assert gm._replace_hooks == []
|