# Owner(s): ["oncall: distributed"]
import sys
from typing import List, Optional
from unittest.mock import patch
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_handle_fqns_from_root
from torch.distributed.fsdp._flat_param import HandleTrainingState
from torch.distributed.fsdp._runtime_utils import (
_get_handle_to_prefetch,
_get_training_state,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

device_type = torch.device(get_devtype())
NUM_ITERS = 2
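
# Per-layer parameter FQN templates (prefixed from the root module) for the
# decoder and encoder layers of nn.Transformer; "{index}" is formatted with the
# layer index when checking the FQNs recorded for each prefetched handle.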
DECODER_PARAM_FQNS = [
"decoder.layers.{index}.self_attn.in_proj_weight",
"decoder.layers.{index}.self_attn.in_proj_bias",
"decoder.layers.{index}.self_attn.out_proj.weight",
"decoder.layers.{index}.self_attn.out_proj.bias",
"decoder.layers.{index}.multihead_attn.in_proj_weight",
"decoder.layers.{index}.multihead_attn.in_proj_bias",
"decoder.layers.{index}.multihead_attn.out_proj.weight",
"decoder.layers.{index}.multihead_attn.out_proj.bias",
"decoder.layers.{index}.linear1.weight",
"decoder.layers.{index}.linear1.bias",
"decoder.layers.{index}.linear2.weight",
"decoder.layers.{index}.linear2.bias",
"decoder.layers.{index}.norm1.weight",
"decoder.layers.{index}.norm1.bias",
"decoder.layers.{index}.norm2.weight",
"decoder.layers.{index}.norm2.bias",
"decoder.layers.{index}.norm3.weight",
"decoder.layers.{index}.norm3.bias",
]
ENCODER_PARAM_FQNS = [
"encoder.layers.{index}.self_attn.in_proj_weight",
"encoder.layers.{index}.self_attn.in_proj_bias",
"encoder.layers.{index}.self_attn.out_proj.weight",
"encoder.layers.{index}.self_attn.out_proj.bias",
"encoder.layers.{index}.linear1.weight",
"encoder.layers.{index}.linear1.bias",
"encoder.layers.{index}.linear2.weight",
"encoder.layers.{index}.linear2.bias",
"encoder.layers.{index}.norm1.weight",
"encoder.layers.{index}.norm1.bias",
"encoder.layers.{index}.norm2.weight",
"encoder.layers.{index}.norm2.bias",
]
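
# nn.Transformer defaults to 6 encoder and 6 decoder layers. With BACKWARD_PRE,
# all 12 layer handles are prefetched (decoder 5...0, then encoder 5...0); with
# BACKWARD_POST, only 11 are (decoder 5 runs backward first and is never
# prefetched). The *_BEGIN_INDEX_* constants mark where the encoder prefetches
# start in the recorded order, and ENCODER_PREFETCH_NUM is the highest encoder
# layer index (encoder layers 5...0 are prefetched, highest first).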
TOTAL_NUM_PREFETCH_FOR_PRE = 12
TOTAL_NUM_PREFETCH_FOR_POST = 11
ENCODER_BEGIN_INDEX_FOR_PRE = 6
ENCODER_BEGIN_INDEX_FOR_POST = 5
ENCODER_PREFETCH_NUM = 5

if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)


class TestBackwardPrefetch(FSDPTest):
@property
def world_size(self):
return 2

    def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
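        # Wrap nn.Transformer with FSDP, patch _get_handle_to_prefetch to
        # record which handle would be prefetched at each backward step, and
        # verify the recorded FQNs against the expected prefetch order.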
rank = self.rank
orig_get_handle_to_prefetch = _get_handle_to_prefetch
torch.manual_seed(0)
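        # Each TransformerEncoderLayer/TransformerDecoderLayer is wrapped in
        # its own FSDP unit, so every layer's parameters form one flat-param
        # handle that backward prefetch can target.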
policy = ModuleWrapPolicy(
{nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
)
model = FSDP(
nn.Transformer(d_model=1024, nhead=8, device=device_type),
device_id=device_type.type,
auto_wrap_policy=policy,
use_orig_params=True,
backward_prefetch=backward_prefetch,
)
optim = torch.optim.SGD(model.parameters(), lr=1e-2)
# prepare input
torch.manual_seed(rank + 1)
src = torch.randn((10, 1, 1024), device=device_type)
tgt = torch.randn((20, 1, 1024), device=device_type)
        # Monkeypatch _get_handle_to_prefetch to record, for each prefetch, the
        # FQNs of the handle that would be prefetched (None when there is none)
        all_handle_fqns: List[Optional[List[str]]] = []

def patched_get_handle_to_prefetch(*args, **kwargs):
handle = orig_get_handle_to_prefetch(*args, **kwargs)
self.assertEqual(
len(args), 2, "expect _get_handle_to_prefetch(state, current_handle)"
)
state = args[0]
current_handle = args[1]
training_state = _get_training_state(current_handle)
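            # only record prefetches fired from the hook matching the
            # configured mode: pre-backward for BACKWARD_PRE, post-backward
            # for BACKWARD_POST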
if (
training_state == HandleTrainingState.BACKWARD_PRE
and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
) or (
training_state == HandleTrainingState.BACKWARD_POST
and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
):
nonlocal all_handle_fqns
# FQNs prefixed from the root module
# state._exec_order_data.param_to_fqn
fqns = _get_handle_fqns_from_root(state, handle)
all_handle_fqns.append(fqns)
return handle
# flat params from prefetch handle should match
# DECODER_PARAM_FQNS and ENCODER_PARAM_FQNS
with patch(
"torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch",
patched_get_handle_to_prefetch,
):
for _ in range(NUM_ITERS):
optim.zero_grad()
loss = model(src, tgt).sum()
loss.backward()
optim.step()
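                # check the prefetch order recorded for this iteration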
if backward_prefetch is None:
self.assertEqual(len(all_handle_fqns), 0)
continue
elif backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
# state._exec_order_data.handles_post_forward_order
# equals forward order
# encoder 0...5 -> decoder 0...5 -> root
# pre-backward hook order
# root -> decoder 5...0 -> encoder 5...0
# prefetch order
# decoder 5...0 -> encoder 5...0 -> None
# None: when current_handle=encoder 0,
# _get_handle_to_prefetch returns None
# +1 is for the above None
encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_PRE
self.assertEqual(
len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_PRE + 1
)
elif backward_prefetch == BackwardPrefetch.BACKWARD_POST:
# state._exec_order_data.handles_post_forward_order
# equals forward order (same as BACKWARD_PRE)
# encoder 0...5 -> decoder 0...5 -> root
# post-backward hook (AccumulateGrad) order
# decoder 5, 4...0 -> encoder 5...0 -> root
# prefetch order
# decoder 4...0 -> encoder 5...0 -> None -> None
# 1st None: when current_handle=encoder 0,
# _get_handle_to_prefetch returns None
# 2nd None: when current_handle=root,
# get decoder 5 inside _get_handle_to_prefetch
# but not needed since decoder 5 is computed already
# +2 is for the above Nones
encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_POST
self.assertEqual(
len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_POST + 2
)
                # ith_prefetch is the 0-based position of each recorded prefetch
for ith_prefetch, fqns in enumerate(all_handle_fqns):
                    if 0 <= ith_prefetch < encoder_begin_index:
layer_index = encoder_begin_index - 1 - ith_prefetch
self.assertEqual(
fqns,
[x.format(index=layer_index) for x in DECODER_PARAM_FQNS],
)
elif (
ith_prefetch >= encoder_begin_index
and ith_prefetch <= encoder_begin_index + ENCODER_PREFETCH_NUM
):
layer_index = (
encoder_begin_index + ENCODER_PREFETCH_NUM - ith_prefetch
)
self.assertEqual(
fqns,
[x.format(index=layer_index) for x in ENCODER_PARAM_FQNS],
)
else:
                            self.assertIsNone(fqns)
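                # reset the recorded FQNs before the next iteration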
all_handle_fqns = []

    @skip_if_lt_x_gpu(2)
def test_backward_prefetch(self):
        # run as subtests to reuse the process group and shorten test time
self.run_subtests(
{
"backward_prefetch": [
None,
BackwardPrefetch.BACKWARD_PRE,
BackwardPrefetch.BACKWARD_POST,
],
},
self._test_backward_prefetch,
)

    def _test_backward_prefetch(self, backward_prefetch: BackwardPrefetch):
self._dist_train(backward_prefetch)
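

# Running this file directly executes the tests via run_tests(); the test is
# skipped when fewer than 2 GPUs are available (skip_if_lt_x_gpu).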
if __name__ == "__main__":
run_tests()