1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
|
# Owner(s): ["oncall: export"]
import unittest
import torch
from functorch.experimental import control_flow
from torch import Tensor
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export.verifier import SpecViolationError, Verifier
from torch.export import export_for_training
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
@unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported")
class TestVerifier(TestCase):
def test_verifier_basic(self) -> None:
class Foo(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
f = Foo()
ep = export_for_training(f, (torch.randn(100), torch.randn(100)))
verifier = Verifier()
verifier.check(ep)
def test_verifier_call_module(self) -> None:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x: Tensor) -> Tensor:
return self.linear(x)
gm = torch.fx.symbolic_trace(M())
verifier = Verifier()
with self.assertRaises(SpecViolationError):
verifier._check_graph_module(gm)
def test_verifier_no_functional(self) -> None:
class Foo(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
f = Foo()
ep = export_for_training(
f, (torch.randn(100), torch.randn(100))
).run_decompositions({})
for node in ep.graph.nodes:
if node.target == torch.ops.aten.add.Tensor:
node.target = torch.ops.aten.add_.Tensor
verifier = Verifier()
with self.assertRaises(SpecViolationError):
verifier.check(ep)
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
def test_verifier_higher_order(self) -> None:
class Foo(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x - y
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
f = Foo()
ep = export_for_training(f, (torch.randn(3, 3), torch.randn(3, 3)))
verifier = Verifier()
verifier.check(ep)
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
def test_verifier_nested_invalid_module(self) -> None:
class Foo(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x - y
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
f = Foo()
ep = export_for_training(
f, (torch.randn(3, 3), torch.randn(3, 3))
).run_decompositions({})
for node in ep.graph_module.true_graph_0.graph.nodes:
if node.target == torch.ops.aten.add.Tensor:
node.target = torch.ops.aten.add_.Tensor
verifier = Verifier()
with self.assertRaises(SpecViolationError):
verifier.check(ep)
def test_ep_verifier_basic(self) -> None:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x: Tensor) -> Tensor:
return self.linear(x)
ep = export_for_training(M(), (torch.randn(10, 10),))
ep.validate()
def test_ep_verifier_invalid_param(self) -> None:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_parameter(
name="a", param=torch.nn.Parameter(torch.randn(100))
)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + self.a
ep = export_for_training(M(), (torch.randn(100), torch.randn(100)))
# Parameter doesn't exist in the state dict
ep.graph_signature.input_specs[0] = InputSpec(
kind=InputKind.PARAMETER, arg=TensorArgument(name="p_a"), target="bad_param"
)
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
ep.validate()
# Add non-torch.nn.Parameter parameter to the state dict
ep.state_dict["bad_param"] = torch.randn(100)
with self.assertRaisesRegex(
SpecViolationError, "not an instance of torch.nn.Parameter"
):
ep.validate()
def test_ep_verifier_invalid_buffer(self) -> None:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.a = torch.tensor(3.0)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + self.a
ep = export_for_training(M(), (torch.randn(100), torch.randn(100)))
# Buffer doesn't exist in the state dict
ep.graph_signature.input_specs[0] = InputSpec(
kind=InputKind.BUFFER,
arg=TensorArgument(name="c_a"),
target="bad_buffer",
persistent=True,
)
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
ep.validate()
def test_ep_verifier_buffer_mutate(self) -> None:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))
def forward(self, x1, x2):
# Use the parameter, buffers, and both inputs in the forward method
output = (
x1 + self.my_parameter
) * self.my_buffer1 + x2 * self.my_buffer2
# Mutate one of the buffers (e.g., increment it by 1)
self.my_buffer2.add_(1.0)
return output
ep = export_for_training(M(), (torch.tensor(5.0), torch.tensor(6.0)))
ep.validate()
def test_ep_verifier_invalid_output(self) -> None:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))
def forward(self, x1, x2):
# Use the parameter, buffers, and both inputs in the forward method
output = (
x1 + self.my_parameter
) * self.my_buffer1 + x2 * self.my_buffer2
# Mutate one of the buffers (e.g., increment it by 1)
self.my_buffer2.add_(1.0)
return output
ep = export_for_training(M(), (torch.tensor(5.0), torch.tensor(6.0)))
output_node = list(ep.graph.nodes)[-1]
output_node.args = (
(
output_node.args[0][0],
next(iter(ep.graph.nodes)),
),
)
with self.assertRaisesRegex(SpecViolationError, "Number of output nodes"):
ep.validate()
if __name__ == "__main__":
    # Run this file's tests through PyTorch's shared test runner.
    run_tests()
|