1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364
|
# Owner(s): ["oncall: distributed"]
import copy
import os
from typing import TYPE_CHECKING
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
PipelineScheduleSingle,
Schedule1F1B,
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleInterleavedZeroBubble,
ScheduleLoopedBFS,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
if TYPE_CHECKING:
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
# MLP layer used as the building block of the test models below.
class MLPModule(torch.nn.Module):
    """Square two-layer perceptron computing ``net2(relu(net1(x)))``.

    Both linear layers map ``d_hid`` features to ``d_hid`` features, so any
    number of these modules can be stacked back to back.
    """

    def __init__(self, d_hid: int):
        super().__init__()
        # Submodule names and registration order fix the parameter FQNs
        # ("net1.weight", "net1.bias", ...) and the order in which the weights
        # are randomly initialized; keep both stable.
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        hidden = self.relu(self.net1(x))
        return self.net2(hidden)
class ComposabilityTest(MultiProcessTestCase):
    """Multi-process tests composing pipeline parallelism (PP) with data
    parallelism (DDP / FSDP) and with distributed checkpointing (DCP).

    Every test runs on ``world_size`` (4) processes that rendezvous through a
    ``FileStore`` backed by ``self.file_name`` and communicate over NCCL, with
    rank ``i`` using CUDA device ``i``.
    """

    @classmethod
    def backend_str(cls) -> str:
        # Testing with NCCL backend
        return "nccl"

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        # Best-effort cleanup of the FileStore file: ignore it if it is already
        # gone or cannot be removed.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 4

    @property
    def device(self):
        # Rank i uses CUDA device i.
        return self.rank

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
    @parametrize("dp_type", ["DDP", "FSDP"])
    @parametrize(
        "ScheduleClass",
        [
            ScheduleGPipe,
            Schedule1F1B,
            ScheduleInterleaved1F1B,
            ScheduleLoopedBFS,
            ScheduleInterleavedZeroBubble,
        ],
    )
    @parametrize("use_new_runtime", [False, True])
    def test_manual_with_data_parallel(self, dp_type, ScheduleClass, use_new_runtime):
        """Check the gradients of a manually split PP model wrapped in DDP/FSDP.

        The 4 ranks form a 2x2 device mesh with dims ("dp", "pp"). Each rank
        builds its pipeline stage(s) from an 8-layer model and wraps them with
        ``dp_type``. It runs one schedule step, then compares every locally held
        gradient against a single-process reference model that was run on all
        DP ranks' inputs.

        NOTE(review): ``use_new_runtime`` is only consulted to skip single-stage
        schedules; multi-stage schedules take the same code path for both values.

        NOTE(review): the reference comparison assumes every rank draws the same
        random weights and inputs (presumably the test harness seeds the RNG
        identically on each rank) -- confirm before changing RNG usage here.
        """
        torch.cuda.set_device(self.device)
        store = torch.distributed.FileStore(self.file_name, self.world_size)
        torch.distributed.init_process_group(
            backend="nccl",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
            # TODO (kwen2501): disabled eager init below as this test is failing
            # with bug fix #139013. Temporarily use lazy init to cover the
            # composability aspect of this test.
            # device_id=torch.device("cuda", self.device),
        )
        # Always tear the process group down, including on the early-exit path
        # and when an assertion fails.
        try:
            device_mesh = init_device_mesh(
                "cuda", mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")
            )
            pp_group = device_mesh["pp"].get_group()
            dp_mesh = device_mesh["dp"]

            # create "entire model"
            total_layers = 8
            dim = 10
            full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
            ref_model = nn.Sequential(*copy.deepcopy(full_model))
            ref_model.to(self.device)

            # One batch per DP rank. This rank feeds its own batch through the
            # pipeline as single-row microbatches.
            num_microbatches = 8
            inputs = [
                torch.rand((num_microbatches, dim), device=self.device)
                for _ in range(dp_mesh.size())
            ]
            local_input = inputs[dp_mesh.get_local_rank()]
            input_mb = [
                [local_input[i].reshape((1, dim))] for i in range(num_microbatches)
            ]

            # dummy loss needed just to force backwards to run in schedule step
            def loss_fn(y, target):
                return y.sum()

            def get_stage_module(stage_idx, num_stages):
                """Return (layers of stage ``stage_idx``, global index of its first layer)."""
                # divide the model (8 layers) evenly by the number of stages
                layers_per_stage = total_layers // num_stages
                assert layers_per_stage * num_stages == total_layers
                # the offset lets validation map a stage-local layer back to the full model
                offset = stage_idx * layers_per_stage
                partial_model = nn.Sequential(
                    *full_model[offset : offset + layers_per_stage]
                )
                partial_model.to(self.device)
                return partial_model, offset

            def apply_dp(partial_model, dp_type):
                """Wrap ``partial_model`` with FSDP (per layer, then root) or DDP over the DP mesh."""
                if dp_type == "FSDP":
                    mp_policy = MixedPrecisionPolicy(
                        # TODO(whc) need to fix PP + FSDP-mixed-precision
                        # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
                        # param_dtype=torch.bfloat16, reduce_dtype=torch.float32
                        param_dtype=torch.float32,
                        reduce_dtype=torch.float32,
                    )
                    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
                    for layer in partial_model.children():
                        fully_shard(
                            layer,
                            **fsdp_config,
                            reshard_after_forward=False,
                        )
                    return fully_shard(partial_model, **fsdp_config)
                if dp_type == "DDP":
                    return DDP(partial_model, process_group=dp_mesh.get_group())
                raise RuntimeError(f"unsupported dp type {dp_type}")

            def build_stage(stage_idx, num_stages):
                """Build the DP-wrapped PipelineStage ``stage_idx``; also return its layer offset."""
                partial_model, offset = get_stage_module(stage_idx, num_stages)
                dp_model = apply_dp(partial_model, dp_type)
                stage = PipelineStage(
                    dp_model,
                    stage_idx,
                    num_stages,
                    self.device,
                    group=pp_group,
                )
                return stage, offset

            # Attach to a schedule
            if issubclass(ScheduleClass, PipelineScheduleSingle):
                if use_new_runtime:
                    # Can't test PipelineScheduleSingle classes using new runtime;
                    # the ``finally`` below still destroys the process group.
                    return
                pipeline_stage, offset = build_stage(pp_group.rank(), pp_group.size())
                partial_models = [pipeline_stage.submod]
                offsets = [offset]
                pipeline_schedule = ScheduleClass(
                    pipeline_stage,
                    n_microbatches=num_microbatches,
                    loss_fn=loss_fn,
                )
            else:
                n_virtual = 2
                pp_size = pp_group.size()
                num_stages = pp_size * n_virtual
                stages = []
                offsets = []
                for i in range(n_virtual):
                    # Multi-stage schedules place stages round-robin over the PP
                    # ranks (stage s lives on PP rank s % pp_size), so this rank
                    # owns stages rank, rank + pp_size, rank + 2 * pp_size, ...
                    stage, offset = build_stage(
                        pp_group.rank() + pp_size * i, num_stages
                    )
                    stages.append(stage)
                    offsets.append(offset)
                partial_models = [stage.submod for stage in stages]
                pipeline_schedule = ScheduleClass(
                    stages,
                    n_microbatches=num_microbatches,
                    loss_fn=loss_fn,
                )

            # Run
            # TODO(whc) should we make it a hard error if you pass arguments into the step API on nonzero ranks?
            # why are we passing inputs/targets on every rank?
            if pp_group.rank() == 0:
                pipeline_schedule._step_microbatches(arg_mbs=input_mb, target_mbs=input_mb)
            else:
                pipeline_schedule._step_microbatches(
                    arg_mbs=[[] for _ in input_mb], target_mbs=input_mb
                )

            # The reference model runs on every DP rank's input and accumulates
            # grads across them. This ensures we detect if the DP gradient
            # reduction becomes a no-op, since each DP rank consumed exactly one
            # of these inputs above.
            for dp_input in inputs:
                ref_model(dp_input).sum().backward()
            # simulate the built-in gradient averaging done by DDP/FSDP
            for p in ref_model.parameters():
                p.grad /= dp_mesh.size()

            # Validate that whichever weights we hold locally match that part of
            # the full reference model. FSDP grads are all-gathered with
            # .full_tensor() to keep the comparison simple.
            ref_parameters = dict(ref_model.named_parameters())
            for partial_model, offset in zip(partial_models, offsets):
                for name, p in partial_model.named_parameters():
                    if dp_type == "DDP":
                        # remove the "module." prefix added by the DDP wrapper
                        name = name.split(".", 1)[1]
                    # shift the stage-local layer index to the global one
                    layer_idx, rest = name.split(".", 1)
                    ref_p = ref_parameters[f"{int(layer_idx) + offset}.{rest}"]
                    if dp_type == "FSDP":
                        self.assertIsInstance(p.grad, DTensor)
                        grad = p.grad.full_tensor()
                    else:
                        grad = p.grad
                    torch.testing.assert_close(ref_p.grad, grad, rtol=1e-5, atol=5e-5)
        finally:
            torch.distributed.destroy_process_group()

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
    def test_pp_and_dcp(self):
        """
        Test that pipeline parallelism and distributed checkpointing can be used
        together, and that the checkpoint is saved with the correct FQNs.
        """

        class AppState(Stateful):
            """Bundle a model and its optimizer so DCP saves/loads them together."""

            def __init__(self, model, optimizer):
                self.model = model
                self.optimizer = optimizer

            def state_dict(self):
                # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
                model_state_dict, optimizer_state_dict = get_state_dict(
                    self.model, self.optimizer
                )
                return {"model": model_state_dict, "optim": optimizer_state_dict}

            def load_state_dict(self, state_dict):
                # sets our state dicts on the model and optimizer, now that we've loaded
                set_state_dict(
                    self.model,
                    self.optimizer,
                    model_state_dict=state_dict["model"],
                    optim_state_dict=state_dict["optim"],
                )

        class PPModelChunk(nn.Module):
            """Layers ``[start_index, end_index)`` of ``layers``, kept under their
            original (global) keys so FQNs match the full model's."""

            def __init__(self, layers: nn.ModuleDict, start_index: int, end_index: int):
                super().__init__()
                # Filter layers based on start_index and end_index
                self.layers = nn.ModuleDict(
                    {str(i): layers[str(i)] for i in range(start_index, end_index)}
                )

            def forward(self, x):
                for layer in self.layers.values():
                    x = layer(x)
                return x

        device = torch.device("cuda", self.device)
        torch.cuda.set_device(self.device)
        store = torch.distributed.FileStore(self.file_name, self.world_size)
        torch.distributed.init_process_group(
            backend="nccl",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
            device_id=device,
        )
        # Always tear the process group down, even if an assertion fails.
        try:
            # create "entire model"
            total_layers = 8
            dim = 10
            full_model = nn.ModuleDict(
                {f"{i}": MLPModule(dim) for i in range(total_layers)}
            )

            # Each rank owns an equal, contiguous, non-overlapping chunk of
            # layers, so parameter FQNs are unique across ranks.
            layers_per_rank = total_layers // self.world_size
            assert layers_per_rank * self.world_size == total_layers
            start_index = self.rank * layers_per_rank
            end_index = start_index + layers_per_rank
            pp_model = PPModelChunk(full_model, start_index, end_index)
            pp_model.to(self.device)
            opt = torch.optim.Adam(pp_model.parameters(), lr=0.1)

            # perform work in a temp dir that is cleaned up after the test
            @with_temp_dir
            def _dcp_test(self):
                state_dict = {"app": AppState(pp_model, opt)}
                dcp.save(state_dict, checkpoint_id=self.temp_dir)

                # Load the raw checkpoint contents into an empty dict (no model
                # required) so the saved keys can be inspected directly.
                sd: STATE_DICT_TYPE = {}
                _load_state_dict(
                    sd,
                    storage_reader=FileSystemReader(self.temp_dir),
                    planner=_EmptyStateDictLoadPlanner(),
                )

                # Every state_dict key of this rank's chunk must appear among the
                # checkpointed model keys.
                pp_model_param_names = set(pp_model.state_dict().keys())
                sd_param_names = set(sd["app"]["model"].keys())
                for param_name in pp_model_param_names:
                    self.assertIn(
                        param_name,
                        sd_param_names,
                        f"Parameter name '{param_name}' not found in state_dict.",
                    )

            _dcp_test(self)
        finally:
            torch.distributed.destroy_process_group()
# Expand the @parametrize-decorated test methods into concrete test cases.
instantiate_parametrized_tests(ComposabilityTest)


if __name__ == "__main__":
    # Hand control to the PyTorch test-suite runner.
    run_tests()
|