import copy
from typing import cast, Dict, List, OrderedDict, Tuple, TypedDict
import numpy as np
import torch
from torch.distributed._tools.mem_tracker import (
_MemRefType,
_ModMemStats,
_ModState,
MemTracker,
)
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStats


class ModOrder(TypedDict):
    fw_pre_order: List[str]
    bw_pre_order: List[str]
    fw_post_order: List[str]
    bw_post_order: List[str]


class ModRuntime(TypedDict):
    fw: float
    bw: float


class ModStats(TypedDict):
fqn: str
# per-module params
param_per_module: int
# per-module grads
grad_per_module: int
# total accumulated gradients up to and including this module
grad_total: int
# per module fw activation size (excluding input and output)
act_fw_per_module: int
# per module bw activation size during peak_bw
act_bw_per_module: int
# per module activation grad size during peak_bw
act_grad_per_module: int
# total activation size up to but excluding the current module
# includes input of the current module (i.e., output of previous module)
act_total: int
# Inputs to the module
input_per_module: int
# Outputs of the module
output_per_module: int
# Total fw run-time of the module
fw_runtime_per_module: float
# Total bw run-time of the module
bw_runtime_per_module: float
# Is this module a leaf module
is_leaf: bool
# Total ac run-time of the module
sac_runtime: float
# Total ac_memory for the module
sac_memory: int
# Number of piecewise-linear functions used for approximating ac tradeoff curve
n_segments: int
    # Slopes of the piecewise-linear functions
    slopes: List[float]
    # Intercepts of the piecewise-linear functions
    intercepts: List[float]
    # X breakpoints of the piecewise-linear functions
    breakpoints: List[float]
# Original trade-off curves
tradeoff_curve: OrderedDict[float, float]
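    # Note: a rough sketch of how the fitted curve is typically read (an
    # assumption based on standard piecewise-linear fits, not guaranteed by this
    # module): for a discard ratio x in [breakpoints[k], breakpoints[k + 1]],
    # the approximate recompute overhead is slopes[k] * x + intercepts[k].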


class ModuleInfo(TypedDict):
    mod_order: ModOrder
    mod_stats: List[ModStats]


def aggregate_stats(
model: torch.nn.Module,
mem_tracker: MemTracker,
runtime_estimator: RuntimeEstimator,
sac_estimator: SACEstimator,
dev: torch.device,
) -> ModuleInfo:
"""
Collect modulewise stats for a given model, including memory, runtime, and AC tradeoff stats.
    Args:
        model: nn.Module object
        mem_tracker: MemTracker object with memory stats
        runtime_estimator: RuntimeEstimator object with runtime stats
        sac_estimator: SACEstimator object with AC tradeoff stats
        dev: device the model was run on (used to extract memory stats from MemTracker)
Returns:
ModuleInfo: A dictionary with module order and module stats.
"""
# Memory stats
mod_mem_stats: Dict[torch.nn.Module, _ModMemStats] = dict(
copy.deepcopy(mem_tracker.memory_tracking)
)
# Runtime stats
mod_runtime_stats: Dict[str, ModRuntime] = {
fqn: {"fw": v["fw"], "bw": v["bw"]}
for fqn, v in runtime_estimator.mod_runtimes.items()
}
# Module order
mod_order: ModOrder = {
"fw_pre_order": list(runtime_estimator.mod_fw_pre_order),
"bw_pre_order": list(runtime_estimator.mod_bw_pre_order),
"fw_post_order": list(runtime_estimator.mod_fw_post_order),
"bw_post_order": list(runtime_estimator.mod_bw_post_order),
}
# Selective Activation Checkpointing stats
sac_estimator.pwlf_sac_tradeoff_curve()
mod_sac_tradeoff_stats: Dict[str, SACTradeOffStats] = copy.deepcopy(
sac_estimator.sac_mod_tradeoff_stats
)
module_info: ModuleInfo = {
"mod_order": mod_order,
"mod_stats": [],
}
for mod in model.modules():
if mod_mem_stat := mod_mem_stats.get(mod, None):
if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None):
sac_runtime = tradeoff_stats.sac_runtime
sac_memory = tradeoff_stats.sac_memory
n_segments = tradeoff_stats.n_segments
slopes = tradeoff_stats.slopes
intercepts = tradeoff_stats.intercepts
breakpoints = tradeoff_stats.fit_breaks
tradeoff_curve = tradeoff_stats.tradeoff_curve
is_leaf = False
            else:
                # Modules without SAC trade-off stats are treated as leaf modules
                # and get zeroed-out SAC fields and an empty trade-off curve.
                sac_runtime = sac_memory = n_segments = 0
                slopes = intercepts = breakpoints = []
                tradeoff_curve: OrderedDict[float, float] = OrderedDict()  # type: ignore[no-redef]
                is_leaf = True
mod_stat: ModStats = {
"fqn": mod_mem_stat.mod_fqn,
"param_per_module": mod_mem_stat.parameter_mem,
"grad_per_module": mod_mem_stat.parameter_mem,
"grad_total": mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][
_MemRefType.GRAD
],
"act_fw_per_module": max(
0,
mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][_MemRefType.ACT]
- mod_mem_stat.snapshots[_ModState.PRE_FW][-1][dev][_MemRefType.ACT]
- mod_mem_stat.output_mem,
),
"act_bw_per_module": max(
0,
mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.ACT],
),
"act_grad_per_module": (
mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.TEMP]
- mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][
_MemRefType.TEMP
]
),
"act_total": mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][
_MemRefType.ACT
],
"input_per_module": mod_mem_stat.input_mem,
"output_per_module": mod_mem_stat.output_mem,
"fw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["fw"],
"bw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["bw"],
"is_leaf": is_leaf,
"sac_runtime": sac_runtime,
"sac_memory": sac_memory,
"n_segments": n_segments,
"slopes": slopes,
"intercepts": intercepts,
"breakpoints": breakpoints,
"tradeoff_curve": tradeoff_curve,
}
module_info["mod_stats"].append(mod_stat)
return module_info
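

# A rough sketch (hypothetical names, made-up numbers) of the structure returned
# by aggregate_stats:
#
#   {
#       "mod_order": {
#           "fw_pre_order": ["Transformer", "Transformer.layers.0", ...],
#           ...
#       },
#       "mod_stats": [
#           {"fqn": "Transformer", "param_per_module": 1024, "is_leaf": False, ...},
#           ...
#       ],
#   }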


class Node(ModStats):
    index: int  # index according to forward pre-order
    pos_fw_post_order: int  # index according to forward post-order


class Graph:
    def __init__(self, n: int) -> None:
        self.nodes: List[Node] = []
        self.name2node: Dict[str, Node] = {}
        # ad_matrix[i][j] == 1 iff module j is module i itself or one of its descendants
        self.ad_matrix = np.zeros((n, n))
        self.fw_post_order: List[str] = []

    def add_node(self, node: Node) -> None:
        self.nodes.append(node)
        self.name2node[node["fqn"]] = node


def parse_module_info(module_info: ModuleInfo) -> Graph:
"""
Parse module info and create a graph (tree) of modules. The graph will be
used by MILP solver to find optimal SAC and/or FSDP configurations.
"""
mod_stats = module_info["mod_stats"]
fw_pre_order = module_info["mod_order"]["fw_pre_order"]
# assertion and number of nodes
assert len(mod_stats) == len(fw_pre_order)
n_nodes = len(mod_stats)
# create graph
g = Graph(n_nodes)
g.fw_post_order = module_info["mod_order"]["fw_post_order"]
# sort the modules by pre-order and add them to the graph
module_info["mod_stats"] = sorted(
mod_stats, key=lambda x: fw_pre_order.index(x["fqn"])
)
    for i, one_mod_stats in enumerate(module_info["mod_stats"]):
node: Node = cast(Node, one_mod_stats)
node["index"] = i
node["pos_fw_post_order"] = g.fw_post_order.index(node["fqn"])
g.add_node(node)
    # set up the ancestor-descendant matrix: nodes are in forward pre-order, so all
    # descendants of node i appear contiguously right after it, and we can stop at
    # the first non-descendant
for i in range(n_nodes):
for j in range(i, n_nodes):
if is_self_or_submodule(g.nodes[j]["fqn"], g.nodes[i]["fqn"]):
g.ad_matrix[i][j] = 1
else:
break
return g
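

# Illustrative example (hypothetical FQNs): for forward pre-order
#   ["root", "root.a", "root.a.lin", "root.b"]
# the ancestor-descendant matrix built by parse_module_info is:
#   [[1, 1, 1, 1],
#    [0, 1, 1, 0],
#    [0, 0, 1, 0],
#    [0, 0, 0, 1]]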


def is_self_or_submodule(name_descendant: str, name_ancestor: str) -> bool:
    """
    Check if name_descendant is name_ancestor itself or one of its submodules.
    """
    return name_descendant == name_ancestor or name_descendant.startswith(
        name_ancestor + "."
    )


def is_submodule(name_descendant: str, name_ancestor: str) -> bool:
    """
    Check if name_descendant is a proper submodule of name_ancestor (not the same module).
    """
    return name_descendant.startswith(name_ancestor + ".")
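
# For example (hypothetical FQNs):
#   is_self_or_submodule("encoder.layers.0", "encoder")  -> True
#   is_self_or_submodule("encoder", "encoder")           -> True
#   is_submodule("encoder", "encoder")                   -> False
#   is_submodule("decoder.layers.0", "encoder")          -> False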


def display_bytes(b: int, unit: str = "MiB") -> str:
    """
    Return a string representing the number of bytes in the desired unit.
    """
if unit == "KiB":
return f"{b/2**10:.2f} KiB"
if unit == "MiB":
return f"{b/2**20:.2f} MiB"
if unit == "GiB":
return f"{b/2**30:.2f} GiB"
return f"{b:.2f} bytes"


def get_peak_memory_runtime_baseline(graph: Graph) -> Tuple[int, float]:
    """
    Get the baseline peak memory and runtime.
    Baseline here means there is no FSDP or AC.
    Memory includes the parameters, gradients, activations, and activation gradients.
    Memory does not include, e.g., optimizer states or embedding tables.
    Returns:
        int: peak memory in bytes
        float: compute time in ms
    """
P_1 = graph.nodes[0]["param_per_module"]
num_nodes = len(graph.nodes)
peak_mem = 0
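    # The baseline peak is the max over modules of
    #   P_1 (all parameters, from the root module)
    #   + TG_i (gradients accumulated up to and including module i)
    #   + AG_i (activation gradients at module i's backward peak)
    #   + TA_i (activations saved up to module i)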
for i in range(num_nodes):
TG_i = graph.nodes[i]["grad_total"]
AG_i = graph.nodes[i]["act_grad_per_module"]
TA_i = graph.nodes[i]["act_total"]
peak_mem = max(peak_mem, P_1 + TG_i + AG_i + TA_i)
    # the root module's runtimes already include all of its submodules
    compute_time = (
graph.nodes[0]["fw_runtime_per_module"]
+ graph.nodes[0]["bw_runtime_per_module"]
)
return (peak_mem, compute_time)
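

# A minimal sketch of how these helpers fit together, assuming a ModuleInfo
# produced by aggregate_stats above (variable names are illustrative only):
#
#   module_info = aggregate_stats(model, mem_tracker, runtime_estimator, sac_estimator, dev)
#   graph = parse_module_info(module_info)
#   peak_mem, runtime_ms = get_peak_memory_runtime_baseline(graph)
#   print(f"baseline: {display_bytes(peak_mem, 'GiB')}, {runtime_ms:.2f} ms")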