# Owner(s): ["module: dynamo"]
import operator
import torch
import torch._dynamo
import torch._dynamo.config as config
import torch._dynamo.test_case
from torch._dynamo.testing import same
from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module
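
# These tests exercise dynamo's ``verify_correctness`` config: when it is
# enabled, dynamo compares the backend-compiled graph's outputs against the
# original graph's outputs on the example inputs and raises an error on
# mismatch (see the docstrings of the tests below).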


class Seq(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        return self.layers(x)


class Conv_Bn_Relu(torch.nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


def toy_example(a, b):
    # The data-dependent branch on b.sum() below typically forces a graph
    # break, so dynamo compiles this function as more than one graph.
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
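

# A deliberately wrong FX pass: rewriting every operator.mul call into
# operator.add changes the semantics of the graph, so a backend built on top
# of it should be caught by verify_correctness.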
def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        # Check whether we're calling a function (e.g. operator.mul).
        if node.op == "call_function":
            # The target attribute is the function that call_function calls;
            # swap multiplication out for addition.
            if node.target == operator.mul:
                node.target = operator.add

    gm.graph.lint()  # Does some checks to make sure the graph is well-formed.
    gm.recompile()
    return gm


@config.patch("verify_correctness", True)
class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
    def test_example_inputs(self):
        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        def compiler_fn(graph, example_inputs):
            # Run the captured graph on the example inputs and stash the first
            # output so the test can compare it with eager execution.
            nonlocal r1
            r1 = graph(*example_inputs)[0]
            return graph.forward

        a = torch.empty(2).fill_(1)
        b = torch.empty(2).fill_(2)
        c = torch.empty(2).fill_(3)
        d = 4
        r1 = None
        r2 = fn(a, (b, c), d)
        opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
        r3 = opt_fn(a, (b, c), d)

        self.assertIsNotNone(r1)
        self.assertEqual(r1.shape, r2.shape)
        self.assertEqual(r1.shape, r3.shape)
        self.assertEqual(r1.device, r2.device)
        self.assertEqual(r1.device, r3.device)

    @_force_skip_lazy_graph_module()
    def test_torchscript(self):
        s = Seq()
        i = torch.randn(10)
        r1 = s(i)
        opt_s = torch.compile(s, backend="ts")
        r2 = opt_s(i)
        self.assertTrue(same(r1, r2))

    def test_incorrect_verify_true(self):
        """
        If a bad optimization returns a graph that is not functionally
        equivalent to the original graph, then with
        config.verify_correctness=True dynamo checks the outputs for
        correctness and raises an error.
        """
        i1 = torch.randn(10)
        i2 = torch.randn(10)

        def incorrect_compile_fn(gm, example_inputs):
            return transform(gm).forward

        toy_example(i1, i2)
        try:
            opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
            opt_toy_example(i1, i2)
        except RuntimeError:
            pass
        else:
            self.fail("expected failure")
@config.patch("verify_correctness", False)
def test_incorrect_verify_false(self):
"""
The bad optimization return a graph that
is not functionally equal to the original graph;
When config.verify_correctness=False, wrong outputs
will return
"""
i1 = torch.randn(10)
i2 = torch.randn(10)
def incorrect_compile_fn(gm, example_inputs):
return transform(gm).forward
r1 = toy_example(i1, i2)
opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
r2 = opt_toy_example(i1, i2)
self.assertTrue(not same(r1, r2))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()