1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
|
# Owner(s): ["oncall: fx"]
import torch
from torch.testing._internal.common_utils import (
TestCase, parametrize, instantiate_parametrized_tests, run_tests)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.dialect.common.cse_pass import CSEPass
from torch.fx.graph_module import GraphModule
import itertools
def FactoryFunctionCall(x, device):
    """Return ``x + 3`` computed through a ``torch.full`` factory call.

    The constant tensor is created on ``device`` with the same shape as ``x``,
    so the traced graph contains a factory op whose output feeds ``torch.add``.
    """
    threes = torch.full(x.shape, 3, device=device)
    return torch.add(threes, x)


def TorchTensorCall(x):
    """Add a scalar tensor created with ``torch.tensor(3)`` to ``x``."""
    return x + torch.tensor(3)


def TakeList(x):
    """Concatenate ``x`` with itself along dim 0 (an op taking a tensor list)."""
    pieces = [x, x]
    return torch.cat(pieces)


def ReturnList(x):
    """Split a fixed ``(5, 2)`` ``arange`` tensor into row chunks of 1 and 4.

    ``x`` is ignored; the function exists to exercise an op (``torch.split``)
    whose result is a sequence of tensors rather than a single tensor.
    """
    grid = torch.arange(10).reshape(5, 2)
    return torch.split(grid, [1, 4])


def Mutation(x):
    """Return ``x + (x + 3)``, built via an in-place update of an intermediate.

    Only the temporary is mutated; ``x`` itself is left untouched.
    """
    tmp = x + 2
    tmp.add_(1)
    return x + tmp


def MutationInput(x):
    """Increment the input ``x`` in place, then return ``x + (x + 2)``.

    Both terms of the sum see the already-incremented ``x``.
    """
    x.add_(1)  # mutates the caller's tensor
    shifted = x + 2
    return x + shifted


def MutationFactory(x, device):
    """Return ``x + 4``, where the 4s come from mutating a ``torch.full`` output.

    The factory tensor is created on ``device`` with ``x``'s shape, then
    bumped in place before being added to ``x``.
    """
    filled = torch.full(x.shape, 3, device=device)
    filled.add_(1)
    return x + filled


def MutationTorchTensorCall(x):
    """Return ``x + 4`` via an in-place bump of a ``torch.tensor(3)`` scalar."""
    scalar = torch.tensor(3)
    scalar.add_(1)
    return x + scalar


def MutationMetadata(x):
    """Resize the input ``x`` in place to 2 elements and return that same tensor.

    This is a metadata (shape) mutation of the caller's tensor.
    """
    x.resize_(2)
    return x


# Pass classes under test; each is constructed with no arguments.
Passes = [CSEPass]

# Programs that take a single tensor input.
# NOTE(review): TorchTensorCall is defined in this module but is not listed
# here; confirm that omission is intentional.
Test_Cases = [
    TakeList,
    ReturnList,
    Mutation,
    MutationInput,
    MutationMetadata,
    MutationTorchTensorCall,
]

# Programs that additionally take a ``device`` argument.
Factory_Test_Cases = [FactoryFunctionCall, MutationFactory]

# CPU is always tested; CUDA is added only if available at import time.
Devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


@instantiate_parametrized_tests
class TestCommonPass(TestCase):
    """Check that each pass in ``Passes`` preserves the outputs of traced programs."""

    def _check_pass_preserves_result(self, common_pass, f, inp, *extra_args):
        """Trace ``f``, apply ``common_pass``, and compare against eager ``f``.

        Args:
            common_pass: pass class; instantiated with no arguments and called
                on the traced GraphModule.
            f: program under test, called as ``f(inp, *extra_args)``.
            inp: input tensor. It is never handed to the tracer directly.
            *extra_args: extra positional arguments (e.g. a device string)
                forwarded to tracing, the eager run, and the transformed module.
        """
        # Trace on a clone so any in-place ops executed while tracing cannot
        # change the tensor later used for the eager reference run.
        traced_m = make_fx(f)(inp.clone(), *extra_args)
        pass_instance = common_pass()
        res = pass_instance(traced_m)
        modified_m = res.graph_module
        # Unlike a bare ``assert``, this is not stripped under ``python -O``.
        self.assertIsInstance(modified_m, GraphModule)

        # Give each run its own copy, since f may mutate its input.
        inp_copy = inp.clone()
        expected = f(inp, *extra_args)
        result = modified_m(inp_copy, *extra_args)
        self.assertEqual(result, expected)

    @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices))
    def test_correctness(self, common_pass, f, device):
        inp = torch.randn(10, device=device)
        self._check_pass_preserves_result(common_pass, f, inp)

    @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices))
    def test_correctness_factory(self, common_pass, f, device):
        inp = torch.randn(10, device=device)
        # Factory programs also take the device, both when traced and when run.
        self._check_pass_preserves_result(common_pass, f, inp, device)


if __name__ == '__main__':
    # Hand control to PyTorch's test runner when executed as a script.
    run_tests()
|