1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
|
# Owner(s): ["module: fx"]
from collections import defaultdict
from typing import Dict, List, Tuple
import torch
from torch.fx.passes.split_utils import split_by_tags
from torch.testing._internal.common_utils import TestCase
class TestFXSplit(TestCase):
def test_split_preserve_node_meta(self):
class TestModule(torch.nn.Module):
def forward(self, x, y):
x = x + x
y = y * y
return x - y
gm = torch.fx.symbolic_trace(TestModule())
for node in gm.graph.nodes:
node.meta["name"] = node.name
if node.name == "add":
node.tag = "a"
elif node.name == "mul":
node.tag = "b"
elif node.name == "sub":
node.tag = "c"
split_gm = split_by_tags(gm, ["a", "b", "c"])
for m in split_gm.children():
for n in m.graph.nodes:
if n.op != "output":
self.assertIn("name", n.meta)
self.assertEqual(n.meta["name"], n.name)
# Validate that metadata is copied correctly for graph placeholder nodes
for node in split_gm.graph.nodes:
if node.op == "placeholder":
self.assertIn("name", node.meta)
self.assertEqual(node.meta["name"], node.name)
class TestSplitByTags(TestCase):
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(2, 3)
self.linear2 = torch.nn.Linear(4, 5)
self.linear3 = torch.nn.Linear(6, 7)
self.linear4 = torch.nn.Linear(8, 6)
def forward(
self,
x1: torch.Tensor,
x2: torch.Tensor,
x3: torch.Tensor,
) -> torch.Tensor:
v1 = self.linear1(x1)
v2 = self.linear2(x2)
v3 = self.linear3(x3)
v4 = torch.cat([v1, v2, v3])
return self.linear4(v4)
@staticmethod
def trace_and_tag(
module: torch.nn.Module, tags: List[str]
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
"""
Test simple gm consists of nodes with tag (only show call_module nodes here):
linear1 - tag: "red"
linear2 - tag: "blue"
linear3, linear4 - tag: "green"
At the beginning we have:
gm:
linear1
linear2
linear3
linear4
split_gm = split_by_tags(gm, tags)
Then we have:
split_gm:
red:
linear1
blue:
linear2
green:
linear3
linear4
"""
tag_node = defaultdict(list)
gm: torch.fx.GraphModule = torch.fx.symbolic_trace(module)
# Add tag to all nodes and build dictionary record tag to call_module nodes
for node in gm.graph.nodes:
if "linear1" in node.name:
node.tag = tags[0]
tag_node[tags[0]].append(node.name)
elif "linear2" in node.name:
node.tag = tags[1]
tag_node[tags[1]].append(node.name)
else:
node.tag = tags[2]
if node.op == "call_module":
tag_node[tags[2]].append(node.name)
return gm, tag_node
def test_split_by_tags(self) -> None:
tags = ["red", "blue", "green"]
module = TestSplitByTags.TestModule()
gm, tag_node = TestSplitByTags.trace_and_tag(module, tags)
split_gm, orig_to_split_fqn_mapping = split_by_tags(
gm, tags, return_fqn_mapping=True
)
# Ensure split_gm has (and only has) ordered submodules named
# red_0, blue_1, green_2
for idx, (name, _) in enumerate(split_gm.named_children()):
if idx < len(tags):
self.assertTrue(
name == tags[idx],
f"split_gm has an incorrect submodule named {name}",
)
# Ensure each submodule has expected (ordered) call_module node(s).
# For example, a submodule named split_gm.red_0 has (and only has) linear1;
# split_gm.green_2 has (and only has) linear3 and linear4 with order
sub_graph_idx = 0
for sub_name, sub_graph_module in split_gm.named_children():
node_idx = 0
for node in sub_graph_module.graph.nodes:
if node.op != "call_module":
continue
self.assertTrue(
node.name == tag_node[f"{sub_name}"][node_idx],
# pyre-fixme[61]: `name` is undefined, or not always defined.
f"{sub_name} has incorrectly include {node.name}",
)
node_idx += 1
sub_graph_idx += 1
self.assertEqual(
orig_to_split_fqn_mapping,
{
"linear1": "red.linear1",
"linear2": "blue.linear2",
"linear3": "green.linear3",
"linear4": "green.linear4",
},
f"{orig_to_split_fqn_mapping=}",
)
class TestSplitOutputType(TestCase):
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()
def forward(self, x):
conv = self.conv(x)
conv = conv * 0.5
relu = self.relu(conv)
return relu
@staticmethod
def trace_and_tag(
module: torch.nn.Module, inputs: torch.Tensor, tags: List[str]
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
"""
Test simple gm consists of nodes with tag (only show call_module nodes here):
conv - tag: "red"
mul - tag: "blue"
relu - tag: "green"
At the beginning we have:
gm:
conv
mul
relu
split_gm = split_by_tags(gm, tags)
Then we have:
split_gm:
red:
conv
blue:
mul
green:
relu
"""
tag_node = defaultdict(list)
gm: torch.fx.GraphModule = torch.export.export(module, (inputs,)).module()
# Add tag to all nodes and build dictionary record tag to call_module nodes
for node in gm.graph.nodes:
if "conv" in node.name:
node.tag = tags[0]
tag_node[tags[0]].append(node.name)
elif "mul" in node.name:
node.tag = tags[1]
tag_node[tags[1]].append(node.name)
else:
node.tag = tags[2]
if node.op == "call_module":
tag_node[tags[2]].append(node.name)
return gm, tag_node
def test_split_by_tags(self) -> None:
tags = ["red", "blue", "green"]
module = TestSplitOutputType.TestModule()
inputs = torch.randn((1, 3, 224, 224))
gm, tag_node = TestSplitOutputType.trace_and_tag(module, inputs, tags)
split_gm, orig_to_split_fqn_mapping = split_by_tags(
gm, tags, return_fqn_mapping=True
)
gm_output = module(inputs)
split_gm_output = split_gm(inputs)
self.assertTrue(type(gm_output) == type(split_gm_output))
self.assertTrue(torch.equal(gm_output, split_gm_output))
|