# mypy: allow-untyped-defs
import logging
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Iterator, List, Set, Tuple
import torch
import torch.distributed as dist
import torch.distributed.fsdp._flat_param as flat_param_file
from torch.distributed.fsdp._common_utils import (
_apply_to_modules,
_get_module_fsdp_state,
clean_tensor_name,
)
logger = logging.getLogger(__name__)
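# SimpleProfiler accumulates wall-clock seconds per profile type across
# `profile()` context blocks. All state is class-level (shared within the
# process), and it is intended for coarse timing of FSDP internals such as
# all-gathers, resharding, and host/device copies.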
class SimpleProfiler:
class Type(str, Enum):
ALL = "all"
ALLGATHER = "all_gather"
ALLGATHER_OBJ = "all_gather_object"
RESHARDING = "resharding"
H2D = "H2D"
D2H = "D2H"
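    # Class-level shared state: accumulated seconds per profile type, and the set
    # of profile types currently inside an active `profile()` block.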
results: Dict[str, float] = defaultdict(float)
profiling: Set[str] = set()
@classmethod
def reset(cls) -> None:
cls.results.clear()
cls.profiling.clear()
@classmethod
@contextmanager
def profile(cls, profile_type: str) -> Iterator[None]:
        assert profile_type not in cls.profiling, (
            f"{profile_type} is already being profiled. SimpleProfiler does not "
            "support profiling the same type in a nested or concurrent manner."
)
cls.profiling.add(profile_type)
begin = time.monotonic()
try:
yield
finally:
end = time.monotonic()
cls.results[profile_type] += end - begin
cls.profiling.remove(profile_type)
@classmethod
def dump_and_reset(cls, msg: str) -> None:
        # This cannot be combined with the DETAIL distributed debug level, as that
        # would make the profiling results very inaccurate.
if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO:
logger.info("%s %s", msg, cls.results)
cls.reset()
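# Example usage (a minimal sketch; the exact profile types and messages used by
# FSDP internals may differ). Note that `dump_and_reset` only logs on rank 0 and
# only when the distributed debug level is INFO:
#
#     with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
#         ...  # run the collective being timed
#     SimpleProfiler.dump_and_reset("FSDP state_dict profiling: ")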
def _get_sharded_module_tree_with_module_name_to_fqns(
model: torch.nn.Module,
) -> Tuple[str, Dict[str, List[str]]]:
"""
    It is used for the composable `fully_shard()` code path and returns:
    1. sharded module tree info: each line represents a submodule name that contains the
    submodule's FQN and its submodule class name. If the submodule is sharded by `fully_shard`,
    the submodule name is suffixed with ' FULLY SHARDED'. Each additional tree
    level adds 4 spaces of indentation before the printed name. The printed sharded module
    tree info for a toy model looks like this:
[CompositeModel] FULLY SHARDED
l1[Linear]
u1[UnitModule] FULLY SHARDED
u1.l1[Linear]
u1.seq[Sequential]
u1.seq.0[ReLU]
u1.seq.1[Linear]
u1.seq.2[ReLU]
u1.l2[Linear]
u2[UnitModule] FULLY SHARDED
u2.l1[Linear]
u2.seq[Sequential]
u2.seq.0[ReLU]
u2.seq.1[Linear]
u2.seq.2[ReLU]
u2.l2[Linear]
l2[Linear]
    2. a dict mapping from the concatenated module FQN and class name to a list of its managed
    original parameters' FQNs. An example of the dict for the above toy sharded model looks like this:
{'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']
}
All FQNs are prefixed starting from ``model``.
Args:
model (torch.nn.Module): Root module (which may or may not be passed to
composable `fully_shard()`).
"""
def module_fn(
module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
):
num_spaces = tree_level * 4
        trimmed_prefix = (
            prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
        )
        prefixed_module_name = trimmed_prefix + "[" + module.__class__.__name__ + "]"
printed_prefixed_module_name = " " * num_spaces + prefixed_module_name
state = _get_module_fsdp_state(module)
if state is None:
sharded_tree_info[0] += printed_prefixed_module_name + "\n"
return
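        # A non-None handle means `module` is itself wrapped by `fully_shard` and
        # owns a FlatParameter; otherwise it is only managed by an ancestor module.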
handle = state._fully_sharded_module_to_handle.get(module, None)
if handle:
sharded_tree_info[0] += (
printed_prefixed_module_name + " FULLY SHARDED" + "\n"
)
else:
sharded_tree_info[0] += printed_prefixed_module_name + "\n"
if handle:
param = handle.flat_param
assert isinstance(param, flat_param_file.FlatParameter)
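            # `param._fqns` holds the FQNs of the original parameters flattened into
            # this FlatParameter, relative to the module owning the handle; prepending
            # `prefix` turns them into global FQNs rooted at `model`.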
global_fqns = [
clean_tensor_name(prefix + name) for name in param._fqns
] # prefixed from the top level `model` (i.e. including `prefix`)
if prefixed_module_name in sharded_module_name_to_fqns:
sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
else:
sharded_module_name_to_fqns[prefixed_module_name] = global_fqns
def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
return sharded_tree_info[0], sharded_module_name_to_fqns
    # Use a list so the accumulated string can be mutated in place by the recursive calls
sharded_tree_info: List[str] = [
"",
]
sharded_module_name_to_fqns: Dict[str, List[str]] = {}
return _apply_to_modules(
model,
module_fn,
return_fn,
[key for key, _ in model.named_parameters()],
sharded_tree_info,
sharded_module_name_to_fqns,
)
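# Example usage (a minimal sketch; assumes `model` has already been composed with
# the composable `fully_shard` API):
#
#     tree_info, module_name_to_fqns = (
#         _get_sharded_module_tree_with_module_name_to_fqns(model)
#     )
#     print(tree_info)              # indented sharded module tree
#     print(module_name_to_fqns)    # '<name>[<ClassName>]' -> managed parameter FQNs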