File: test_fx_node_hook.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (91 lines) | stat: -rw-r--r-- 3,205 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Owner(s): ["module: fx"]
import torch
from torch.fx import symbolic_trace
from torch.testing._internal.common_utils import TestCase


class TestFXNodeHook(TestCase):
    def test_hooks_for_node_update(self):
        global create_node_hook1_called
        global create_node_hook2_called
        global erase_node_hook1_called
        global erase_node_hook2_called
        global replace_node_hook1_called
        global replace_node_hook2_called
        create_node_hook1_called = False
        create_node_hook2_called = False
        erase_node_hook1_called = False
        erase_node_hook2_called = False
        replace_node_hook1_called = False
        replace_node_hook2_called = False

        def fn(a, b, c):
            x = torch.nn.functional.linear(a, b)
            x = x + c
            return x.cos()

        def create_node_hook1(node):
            global create_node_hook1_called
            create_node_hook1_called = True

        def create_node_hook2(node):
            global create_node_hook2_called
            create_node_hook2_called = True

        def erase_node_hook1(node):
            global erase_node_hook1_called
            erase_node_hook1_called = True

        def erase_node_hook2(node):
            global erase_node_hook2_called
            erase_node_hook2_called = True

        def replace_node_hook1(old, new, user):
            global replace_node_hook1_called
            self.assertEqual(old.name, "a")
            self.assertEqual(new, "a_1")
            self.assertEqual(user.name, "linear")
            replace_node_hook1_called = True

        def replace_node_hook2(old, new, user):
            global replace_node_hook2_called
            replace_node_hook2_called = True

        gm = symbolic_trace(fn)
        gm._register_create_node_hook(create_node_hook1)
        gm._register_create_node_hook(create_node_hook2)
        gm._register_erase_node_hook(erase_node_hook1)
        gm._register_erase_node_hook(erase_node_hook2)
        gm._register_replace_node_hook(replace_node_hook1)
        gm._register_replace_node_hook(replace_node_hook2)

        graph = gm.graph
        node_a = None
        for node in graph.find_nodes(op="placeholder"):
            node_a = node
            break
        assert node_a is not None
        # This will create a new node
        node_a_copy = graph.node_copy(node_a)
        node_a.replace_all_uses_with(node_a_copy)
        graph.erase_node(node_a)

        assert (
            create_node_hook1_called
            and create_node_hook2_called
            and erase_node_hook1_called
            and erase_node_hook2_called
            and replace_node_hook1_called
            and replace_node_hook2_called
        )

        gm._unregister_create_node_hook(create_node_hook1)
        gm._unregister_create_node_hook(create_node_hook2)
        gm._unregister_erase_node_hook(erase_node_hook1)
        gm._unregister_erase_node_hook(erase_node_hook2)
        gm._unregister_replace_node_hook(replace_node_hook1)
        gm._unregister_replace_node_hook(replace_node_hook2)

        assert gm._create_node_hooks == []
        assert gm._erase_node_hooks == []
        assert gm._replace_hooks == []