# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from typing import Any, Dict
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.api import distribute_tensor, DTensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
MLPModule,
MLPStacked,
ModelArgs,
NUM_DEVICES,
skip_unless_torch_gpu,
Transformer,
with_comms,
)
# Alias for the functional-collectives op namespace; the comm_module_counts
# assertions below are keyed by op packets taken from here (e.g. all_reduce).
c10d_functional = torch.ops.c10d_functional


class TestCommModeFeatures(DTensorTestBase):
    """
    Tests CommDebugMode's parameter/sharding reporting and its module-level
    tracing of collectives on tensor-parallel models.
    """

    def _build_tp_mlp(self):
        """
        Build the setup shared by every test in this class.

        Creates a 1-D device mesh over ``NUM_DEVICES`` ranks, an input of size
        [8, 10] drawn after seeding with 0, and an ``MLPModule`` parallelized
        with ``net1`` colwise and ``net2`` rowwise.

        Returns:
            ``(device_mesh, inp, model)``
        """
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, NUM_DEVICES),
        )
        inp_size = [8, 10]
        torch.manual_seed(0)
        inp = torch.rand(*inp_size, device=self.device_type)
        model = MLPModule(self.device_type)
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model = parallelize_module(model, device_mesh, parallelize_plan)
        return device_mesh, inp, model

    def check_same_set_of_keys(self, dict1, dict2):
        """
        Used to ensure the comm_mode parameter/sharding dictionaries contain the same
        information produced by the ground truth.

        Flattens each two-level mapping into ``(key, inner_item)`` pairs, where
        ``inner_item`` is whatever iterating ``dict[key]`` yields. For parameter
        dicts that is the parameter names; for sharding dicts it is the
        placements. The pairs must match in both content and iteration order.
        """
        dict1_keys = [(key, nested_key) for key in dict1 for nested_key in dict1[key]]
        dict2_keys = [(key, nested_key) for key in dict2 for nested_key in dict2[key]]
        # One sequence comparison covers the length check and the ordered,
        # element-wise comparison.
        self.assertEqual(dict1_keys, dict2_keys)

    def ground_truth(self, model):
        """
        Used to generate the ground-truth parameter and sharding info for a given
        distributed model to verify comm_mode correctness.

        Module FQNs are prefixed with the root model's class name
        (e.g. ``"MLPModule.net1"``).

        Returns:
            ``(module_parameters_dict, module_sharding_dict)``, where:

            - ``module_parameters_dict`` maps module FQN ->
              ``{parameter_name: parameter.data}``.
            - ``module_sharding_dict`` maps ``"<module FQN>.<parameter_name>"`` ->
              ``placements``, for every parameter whose data is a ``DTensor``.
        """
        module_parameters_dict: Dict[str, Any] = {}
        module_sharding_dict: Dict[str, Any] = {}
        root_name = model.__class__.__name__
        for name, param in model.named_parameters():
            # Split "a.b.weight" into the owning submodule path "a.b" and the
            # parameter name "weight". rpartition, unlike rsplit(".", 1)[1],
            # does not raise for a parameter registered directly on the root.
            submodule_path, _, parameter_name = name.rpartition(".")
            # NOTE(review): root-level parameters are attributed to the bare
            # class name, matching the root key ("MLPModule") that
            # comm_module_counts uses below -- confirm against CommDebugMode.
            module_name = (
                f"{root_name}.{submodule_path}" if submodule_path else root_name
            )
            module_parameters_dict.setdefault(module_name, {})[
                parameter_name
            ] = param.data
            if isinstance(param.data, DTensor):
                key_name = f"{module_name}.{parameter_name}"
                module_sharding_dict[key_name] = param.data.placements
        return module_parameters_dict, module_sharding_dict

    @with_comms
    def test_MLP_distributed_sharding_display(self):
        """
        tests parameters and sharding on a module level
        """
        _, inp, model = self._build_tp_mlp()
        comm_mode = CommDebugMode()
        with comm_mode:
            output_tp = model(inp)
            output_tp.sum().backward()
        module_parameters_dict, module_sharding_dict = self.ground_truth(model)
        # checks if parameter / sharding info is the same as ground truth
        self.check_same_set_of_keys(
            module_parameters_dict, comm_mode.get_parameter_info()
        )
        self.check_same_set_of_keys(module_sharding_dict, comm_mode.get_sharding_info())

    @with_comms
    def test_MLPStacked_distributed_sharding_display(self):
        """
        tests model with nested modules and makes sure comm_mode correctly resets
        parameter and sharding information
        """
        device_mesh, inp, model = self._build_tp_mlp()
        comm_mode = CommDebugMode()
        # The first run fills comm_mode, so re-entering it below must reset it.
        with comm_mode:
            output_tp = model(inp)
            output_tp.sum().backward()
        model2 = MLPStacked(self.device_type)
        # Plan keys are FQNs relative to model2 itself ("layers.0.net1"), not
        # prefixed with the class name the way ground_truth's output is.
        parallelize_plan = {
            "layers.0.net1": ColwiseParallel(),
            "layers.0.net2": RowwiseParallel(),
            "layers.1.net1": ColwiseParallel(),
            "layers.1.net2": RowwiseParallel(),
        }
        model2 = parallelize_module(model2, device_mesh, parallelize_plan)
        with comm_mode:
            # ensures that comm_mode is resetting properly
            self.assertEqual(comm_mode.get_parameter_info(), {})
            self.assertEqual(comm_mode.get_sharding_info(), {})
            model2(inp)
        module_parameters_dict, module_sharding_dict = self.ground_truth(model2)
        self.check_same_set_of_keys(
            module_parameters_dict, comm_mode.get_parameter_info()
        )
        self.check_same_set_of_keys(module_sharding_dict, comm_mode.get_sharding_info())
        # Expected number of sharded parameters across the four parallelized
        # linears (presumably a weight and a bias each).
        self.assertEqual(len(comm_mode.get_sharding_info()), 8)

    @with_comms
    def test_MLP_module_tracing(self):
        """
        tests module-level tracing for MLP module
        """
        _, inp, model = self._build_tp_mlp()
        comm_mode = CommDebugMode()
        with comm_mode:
            output_tp = model(inp)
            output_tp.sum().backward()
        # checks to see if all sub-modules make it into the module_depth_dictionary
        self.assertEqual(len(comm_mode.advanced_module_tracker.module_helper_dict), 5)
        # checks to see if all collectives were correctly traced at the module-level
        for module_fqn in ("Global", "MLPModule", "MLPModule.net2"):
            self.assertEqual(
                comm_mode.comm_module_counts[module_fqn]["forward"][
                    c10d_functional.all_reduce
                ],
                1,
                msg=lambda generated, fqn=module_fqn: f"{fqn}: {generated}",
            )

    @skip_unless_torch_gpu
    @with_comms
    def test_transformer_module_tracing(self, is_seq_parallel=False):
        """
        tests module-level tracing for more complicated transformer module and
        ensures that comm_module depth and tracing dictionaries correctly reset
        """
        device_mesh, inp, model = self._build_tp_mlp()
        comm_mode = CommDebugMode()
        with comm_mode:
            # Before anything runs, only the "Global" entry is present.
            self.assertEqual(
                len(comm_mode.advanced_module_tracker.module_helper_dict), 1
            )
            self.assertEqual(
                comm_mode.comm_module_counts,
                {"Global": {"forward": {}, "backward": {}}},
            )
            model(inp)
        model_args = ModelArgs(dropout_p=0.0)
        model2 = Transformer(model_args).to(device=self.device_type)
        model2 = Transformer.parallelize(model2, device_mesh, is_seq_parallel)
        inp_size = [8, 8]
        inp = torch.randint(model_args.vocab_size, inp_size, device=self.device_type)
        inp = distribute_tensor(inp, device_mesh=device_mesh)
        comm_mode = CommDebugMode()
        with comm_mode:
            model2(inp)
        # checks to see if all collectives were correctly traced at the module-level
        all_reduce = c10d_functional.all_reduce
        all_gather = c10d_functional.all_gather_into_tensor
        expected_forward_counts = [
            ("Global", all_reduce, 6),
            ("Global", all_gather, 1),
            ("Transformer", all_reduce, 6),
            ("Transformer", all_gather, 1),
            ("Transformer.tok_embeddings", all_reduce, 1),
            ("Transformer.pos_embeddings", all_reduce, 1),
            ("Transformer.layers.0", all_reduce, 2),
            ("Transformer.layers.0.attention", all_reduce, 1),
            ("Transformer.layers.0.attention.wo", all_reduce, 1),
            ("Transformer.layers.0.feed_forward", all_reduce, 1),
            ("Transformer.layers.0.feed_forward.w2", all_reduce, 1),
            ("Transformer.layers.1", all_reduce, 2),
            ("Transformer.layers.1.attention", all_reduce, 1),
            ("Transformer.layers.1.attention.wo", all_reduce, 1),
            ("Transformer.layers.1.feed_forward", all_reduce, 1),
            ("Transformer.layers.1.feed_forward.w2", all_reduce, 1),
            ("Transformer.output", all_gather, 1),
        ]
        for module_fqn, op, expected_count in expected_forward_counts:
            self.assertEqual(
                comm_mode.comm_module_counts[module_fqn]["forward"][op],
                expected_count,
                # Prefix the generated mismatch message with the failing entry.
                msg=lambda generated, fqn=module_fqn, op=op: (
                    f"{fqn} forward {op}: {generated}"
                ),
            )
# Script entry point: hand control to PyTorch's common test runner.
if __name__ == "__main__":
    run_tests()