# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed.fsdp._trace_utils import _ExecOrderTracer
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase
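

# Toy model with a nested `Sequential`, modules reused across the forward pass
# (`layer0` and `relu` each run more than once), and a parameter that never
# participates in `forward` (`weight_unused`).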
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight1 = torch.nn.Parameter(torch.randn(6, 6))
        self.weight2 = torch.nn.Parameter(torch.randn(6, 6))
        self.weight_unused = torch.nn.Parameter(torch.randn(2, 2))
        self.layer0 = torch.nn.Linear(6, 6)
        self.layer1 = torch.nn.Linear(6, 6, bias=False)
        self.layer2 = torch.nn.Sequential(
            torch.nn.Linear(6, 3, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(3, 6, bias=False),
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor, run_all_layers: bool) -> torch.Tensor:
        z = self.relu(self.layer0(x))
        z = self.relu(self.layer2(z))
        z = z @ self.weight1
        if run_all_layers:
            z = self.relu(self.layer1(z))
            z = z @ self.weight2
            # Use `layer0` twice to check the handling of multiplicity in the
            # saved data structures
            z = self.relu(self.layer0(x))
        return z
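

# The expected orders checked below assume tracing with
# `concrete_args={"run_all_layers": True}`, so the conditional branch in
# `Model.forward` runs and `layer0` appears twice in the recorded order.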
class TestSymbolicTracing(TestCase):
    def test_symbolic_tracing_outputs(self):
        """
        Tests running ``tracer.trace()`` inside ``patch_tracer()`` by checking
        the saved data structures.
        """
        model = Model()
        tracer = torch.fx.Tracer()
        orig_call_module = tracer.call_module
        orig_create_proxy = tracer.create_proxy
        exec_order_tracer = _ExecOrderTracer()
        with exec_order_tracer.patch_tracer(tracer=tracer, root_module=model):
            concrete_args = {"run_all_layers": True}
            tracer.trace(model, concrete_args)
        # Check that the tracer methods are unchanged after exiting the context
        self.assertEqual(orig_call_module, tracer.call_module)
        self.assertEqual(orig_create_proxy, tracer.create_proxy)
        # Check `module_forward_order`
        correct_module_forward_order = [
            model,
            model.layer0,
            model.relu,
            model.layer2,
            model.layer2[0],
            model.layer2[1],
            model.layer2[2],
            model.relu,
            model.layer1,
            model.relu,
            model.layer0,
            model.relu,
        ]
        exec_info = exec_order_tracer.exec_info
        self.assertEqual(exec_info.module_forward_order, correct_module_forward_order)
        # Check `module_to_param_usage_infos`
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model],
            [
                (model.layer0, list(model.layer0.named_parameters())),
                (model.layer2, list(model.layer2.named_parameters())),
                (model, [("weight1", model.weight1)]),
                (model.layer1, list(model.layer1.named_parameters())),
                (model, [("weight2", model.weight2)]),
                (model.layer0, list(model.layer0.named_parameters())),
            ],
        )
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model.layer0],
            [(model.layer0, list(model.layer0.named_parameters()))],
        )
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model.layer1],
            [(model.layer1, list(model.layer1.named_parameters()))],
        )
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model.layer2],
            [
                (model.layer2[0], list(model.layer2[0].named_parameters())),
                (model.layer2[2], list(model.layer2[2].named_parameters())),
            ],
        )
        self.assertEqual(exec_info.module_to_param_usage_infos[model.relu], [])
        # Check `param_forward_order`
        correct_param_order = [
            model.layer0.weight,
            model.layer0.bias,
            model.layer2[0].weight,
            model.layer2[2].weight,
            model.weight1,
            model.layer1.weight,
            model.weight2,
        ]
        self.assertEqual(exec_info.param_forward_order, correct_param_order)
        # Check `visited_params`
        self.assertEqual(
            len(exec_info.visited_params), len(exec_info.param_forward_order)
        )
        self.assertEqual(exec_info.visited_params, set(exec_info.param_forward_order))


devices = ("cuda", "hpu")
instantiate_device_type_tests(TestSymbolicTracing, globals(), only_for=devices)


if __name__ == "__main__":
    run_tests()