# Owner(s): ["oncall: fx"]
import itertools
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph_module import GraphModule
from torch.fx.passes.dialect.common.cse_pass import CSEPass
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
def FactoryFunctionCall(x, device):
    """Add ``x`` to a constant tensor produced by a factory function.

    A tensor filled with 3 is created with ``torch.full`` using ``x``'s shape
    on ``device``; the result of ``torch.add`` of that tensor and ``x`` is
    returned.
    """
    threes = torch.full(x.shape, 3, device=device)
    return torch.add(threes, x)
def TorchTensorCall(x):
    """Return ``x`` plus a scalar tensor built with ``torch.tensor(3)``."""
    three = torch.tensor(3)
    return x + three
def TakeList(x):
    """Concatenate ``x`` with itself by handing a list to ``torch.cat``."""
    return torch.cat([x, x])
def ReturnList(x):
    """Return the chunks ``torch.split`` yields for a fixed 5x2 tensor.

    The tensor is ``torch.arange(10)`` reshaped to ``(5, 2)`` and split into
    pieces of 1 and 4 rows. ``x`` is not read by this function.
    """
    base = torch.arange(10).reshape(5, 2)
    return torch.split(base, [1, 4])
def Mutation(x):
    """Return ``x`` plus an intermediate tensor that was updated in place.

    The intermediate is ``x + 2``; it is incremented by 1 with ``add_``
    before being added back to ``x``.
    """
    shifted = x + 2
    # In-place update of the intermediate tensor, not of ``x``.
    shifted.add_(1)
    return x + shifted
def MutationInput(x):
    """Increment ``x`` in place, then return ``x + (x + 2)``.

    The caller's tensor is modified: ``add_(1)`` is applied directly to ``x``
    before the sum is computed.
    """
    # In-place update of the input itself.
    x.add_(1)
    shifted = x + 2
    return x + shifted
def MutationFactory(x, device):
    """Add ``x`` to a factory-created tensor that was updated in place.

    A tensor filled with 3 is created with ``torch.full`` using ``x``'s shape
    on ``device``, incremented by 1 with ``add_``, and then added to ``x``.
    """
    filled = torch.full(x.shape, 3, device=device)
    # In-place update of the freshly created tensor.
    filled.add_(1)
    return x + filled
def MutationTorchTensorCall(x):
    """Add ``x`` to a ``torch.tensor(3)`` scalar that was updated in place.

    The scalar tensor is incremented by 1 with ``add_`` before the addition.
    """
    scalar = torch.tensor(3)
    # In-place update of the scalar tensor.
    scalar.add_(1)
    return x + scalar
def MutationMetadata(x):
    """Resize ``x`` in place to 2 elements and return the same tensor."""
    # ``resize_`` changes the tensor's size in place; ``x`` itself is returned.
    x.resize_(2)
    return x
# Pass classes exercised by the parametrized tests in ``TestCommonPass``.
Passes = [CSEPass]

# Functions taking a single tensor argument; used by ``test_correctness``.
# NOTE(review): ``TorchTensorCall`` is defined in this file but is not listed
# here -- confirm whether the omission is intentional.
Test_Cases = [
    TakeList,
    ReturnList,
    Mutation,
    MutationInput,
    MutationMetadata,
    MutationTorchTensorCall,
]

# Functions that additionally take a ``device`` argument; used by
# ``test_correctness_factory``.
Factory_Test_Cases = [FactoryFunctionCall, MutationFactory]

# Devices to parametrize over; "cuda" is included only when it is available.
Devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
def name_fn(common_pass, f, device):
    """Build the name suffix for one parametrized test case.

    The suffix has the form ``<pass type name>_<function name>_<device>``.
    The pass name is read from the type of a freshly constructed
    ``common_pass()`` instance.
    """
    pass_name = type(common_pass()).__name__
    func_name = f.__name__
    return f"{pass_name}_{func_name}_{device}"
@instantiate_parametrized_tests
class TestCommonPass(TestCase):
    """Check that each pass in ``Passes`` leaves a traced function's output unchanged.

    Every test traces a function with ``make_fx``, runs the pass on the
    resulting graph module, and compares the transformed module's output with
    the output of calling the original function directly.
    """

    def _check_pass_preserves_output(self, common_pass, f, inp, *extra_args):
        """Trace ``f``, apply ``common_pass``, and compare against eager ``f``.

        Args:
            common_pass: Zero-argument callable returning a pass instance; the
                instance is called on the traced module and its result must
                expose ``graph_module``.
            f: Function under test, called as ``f(inp, *extra_args)``.
            inp: Input tensor used both for tracing and for the comparison.
            *extra_args: Additional positional arguments forwarded unchanged
                to tracing, to ``f`` and to the transformed module.
        """
        traced_m = make_fx(f)(inp, *extra_args)
        modified_m = common_pass()(traced_m).graph_module
        # Use a unittest assertion instead of a bare ``assert``: it is not
        # stripped under ``python -O`` and reports the offending type.
        self.assertIsInstance(modified_m, GraphModule)

        # Snapshot the input after tracing and before the eager call, so an
        # in-place update made by ``f(inp)`` does not leak into the input
        # given to the transformed module.
        inp_copy = inp.clone()
        expected = f(inp, *extra_args)
        result = modified_m(inp_copy, *extra_args)

        self.assertEqual(result, expected)

    @parametrize(
        "common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn
    )
    def test_correctness(self, common_pass, f, device):
        """Single-argument functions from ``Test_Cases``."""
        inp = torch.randn(10, device=device)
        self._check_pass_preserves_output(common_pass, f, inp)

    @parametrize(
        "common_pass,f,device",
        itertools.product(Passes, Factory_Test_Cases, Devices),
        name_fn,
    )
    def test_correctness_factory(self, common_pass, f, device):
        """Functions from ``Factory_Test_Cases`` that also take ``device``."""
        inp = torch.randn(10, device=device)
        self._check_pass_preserves_output(common_pass, f, inp, device)
# Script entry point: hand control to the PyTorch test runner.
if __name__ == "__main__":
    run_tests()
|