# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
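
# Tests for the pipelining frontend: models annotated with pipe_split() are
# traced and split into stages; we check the stage count, output parity with
# the original model, and the state_dict qualified names (FQNs).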

from model_registry import MLPModule, ModelWithParamAlias

import torch
from torch.distributed.pipelining import pipe_split, pipeline
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)
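

# Hidden dimension and microbatch size shared by the test models and inputs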
d_hid = 512
microbatch_size = 16

torch.manual_seed(0)


# Basic example: a model split by explicit pipe_split() markers, with a
# multi-use parameter and a skip connection
class ExampleCode(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = torch.nn.Linear(d_hid, d_hid)
        self.lin2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        x = torch.mm(x, self.mm_param1)  # multi-use param (first use)
        skip_connection = x
        x = x + y
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)  # multi-use param (second use)
        x = self.lin1(x)
        pipe_split()
        x = torch.relu(x)
        x = x + skip_connection
        x = torch.mm(x, self.mm_param2)
        pipe_split()
        x = self.lin2(x)
        x = torch.relu(x)
        return x
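

# A chain of four MLP blocks, with a pipe_split() between consecutive blocks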
class MultiMLP(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mlp0 = MLPModule(d_hid)
        self.mlp1 = MLPModule(d_hid)
        self.mlp2 = MLPModule(d_hid)
        self.mlp3 = MLPModule(d_hid)

    def forward(self, x, y):
        x = self.mlp0(x)
        pipe_split()
        x = self.mlp1(x)
        pipe_split()
        x = self.mlp2(x)
        pipe_split()
        x = self.mlp3(x)
        return x - y
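

# Expected number of pipeline stages: one more than the number of pipe_split()
# calls in each model's forward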
EXPECTED_N_STAGES = {
    ExampleCode: 4,
    MultiMLP: 4,
    ModelWithParamAlias: 2,
}


# Currently, we don't enforce full set equality on the FQNs between the original
# and pipelined models, because in the multi-use param case, PP will deduplicate
# the FQNs from the state_dict.
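# (For example, `mm_param1` above feeds two stages but is expected to appear
# only once across the combined stage state_dicts.)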
# TODO
CHECK_FQN_SET_EQUALITY = False


class PipeTests(TestCase):
    @parametrize("ModelClass", [ExampleCode, MultiMLP, ModelWithParamAlias])
    def test_model_split(self, ModelClass):
        mod = ModelClass()
        x = torch.randn(microbatch_size, d_hid)
        y = torch.randn(microbatch_size, d_hid)
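
        # Trace the model and split it into pipeline stages at the
        # pipe_split() markers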
        pipe = pipeline(
            mod,
            mb_args=(x, y),
        )
        assert (
            pipe.num_stages == EXPECTED_N_STAGES[ModelClass]
        ), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}"
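
        # The split model, run end to end on one microbatch, should match the
        # output of the original model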
        ref_out = mod(x, y)
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref_out)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}")

        # Check qualname
        # state_dict.keys include both parameters and persistent buffers
        old_names = set(mod.state_dict().keys())
        new_names = set()
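        # Collect each stage's FQNs; every one must come from the original model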
        for idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(idx)
            stage_fqns = set(stage_mod.state_dict().keys())
            assert stage_fqns.issubset(old_names)
            new_names.update(stage_fqns)

        if CHECK_FQN_SET_EQUALITY:
            assert (
                old_names == new_names
            ), f"""
            old names {old_names}
            new names {new_names}
            """
        print("Qualname check passed")
instantiate_parametrized_tests(PipeTests)

if __name__ == "__main__":
    run_tests()