# Owner(s): ["module: fx"]
import unittest

import torch
import torch.fx
from torch.testing._internal.common_utils import TestCase
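
# These tests exercise torch.fx.Tracer(param_shapes_constant=True): with that
# flag, shape queries on module parameters (.shape, .size(), .dim(), .ndim,
# .numel(), .nelement()) are evaluated to Python constants during symbolic
# tracing, so control flow that branches on them traces exactly one branch.
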
class MyModuleBase(torch.nn.Module):
    def forward(self, x):
        matrix = self.get_mul_matrix()
        if self.no_relu():
            return torch.mm(x, matrix)
        else:
            return torch.relu(torch.mm(x, matrix))

    def get_mul_matrix(self):
        return self.param

    def no_relu(self):
        raise NotImplementedError("no_relu() must be implemented by subclasses")
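

# Each subclass below gates the relu branch on a different shape-query API.
# MyModuleParamShape branches on the Tensor.shape attribute.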
class MyModuleParamShape(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.shape[0] < 10
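

# Branches on Tensor.size(), the method form of .shape.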
class MyModuleParamSize(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.size()[0] < 10
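

# Branches on Tensor.dim(), which also selects the multiplication matrix:
# a 3-d parameter contributes its first slice, a 2-d parameter is used as is.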
class MyModuleParamDim(MyModuleBase):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        return self.param[0] if (self.param.dim() == 3) else self.param

    def no_relu(self):
        return self.param.dim() == 3
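

# Branches on the Tensor.ndim property, the attribute form of .dim().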
class MyModuleParamNDim(MyModuleBase):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        return self.param[0] if (self.param.ndim == 3) else self.param

    def no_relu(self):
        return self.param.ndim == 3
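

# Branches on Tensor.numel(), the total element count.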
class MyModuleParamNumEl(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.numel() < 10 * 3
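

# Branches on Tensor.nelement(), an alias of .numel().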
class MyModuleParamNElement(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.nelement() < 10 * 3
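

# Each test traces a pair of modules with param_shapes_constant=True: one
# whose parameter shape selects the mm-only branch and one whose shape
# selects the mm + relu branch, then checks that the traced graphs compute
# the same results as eager mode.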
class TestConstParamShapeInControlFlow(TestCase):
    def verify_mm_relu_mods(self, mm_only_mod, relu_mod):
        """
        Verify that one module performs only an mm op while the other
        performs both mm and relu ops in cascade.
        """
        x = torch.randn(10, 5)
        torch.testing.assert_close(mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix()))
        tracer = torch.fx.Tracer(param_shapes_constant=True)
        traced_graph = tracer.trace(mm_only_mod)

        # Verify the graph module calculates the same result
        graph_mod_mm = torch.fx.GraphModule(mm_only_mod, traced_graph)
        torch.testing.assert_close(graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix()))

        # relu_mod was built with a parameter shape that takes the other code
        # path, so its input needs a matching inner dimension
        x = torch.randn(10, 15)
        torch.testing.assert_close(relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))
        tracer2 = torch.fx.Tracer(param_shapes_constant=True)
        traced_graph2 = tracer2.trace(relu_mod)

        # Verify the graph module calculates the same result
        graph_mod_relu = torch.fx.GraphModule(relu_mod, traced_graph2)
        torch.testing.assert_close(graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))

        graph1_node_targets = [n.target for n in traced_graph.nodes]
        graph2_node_targets = [n.target for n in traced_graph2.nodes]

        # The second graph has an extra relu function call node
        self.assertIn(torch.mm, graph1_node_targets)
        self.assertIn(torch.mm, graph2_node_targets)
        self.assertIn(torch.relu, graph2_node_targets)
        self.assertNotIn(torch.relu, graph1_node_targets)

    def test_param_shape_const(self):
        mymod = MyModuleParamShape(in_channels=5)
        mymod2 = MyModuleParamShape(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_size_const(self):
        mymod = MyModuleParamSize(in_channels=5)
        mymod2 = MyModuleParamSize(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_dim_const(self):
        mymod = MyModuleParamDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_ndim_const(self):
        mymod = MyModuleParamNDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamNDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_numel_const(self):
        mymod = MyModuleParamNumEl(in_channels=5)
        mymod2 = MyModuleParamNumEl(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_nelement_const(self):
        mymod = MyModuleParamNElement(in_channels=5)
        mymod2 = MyModuleParamNElement(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)


if __name__ == '__main__':
    unittest.main()