1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
|
# Owner(s): ["oncall: distributed"]
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
FSDPTest,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.utils.checkpoint import checkpoint
# Exit with status 0 before defining any tests when the environment cannot
# run them: first when torch.distributed is unavailable, then under dev-asan.
if not dist.is_available():
    print(
        "Distributed not available, skipping tests",
        file=sys.stderr,
    )
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
def get_cur_mem(rank, result, prefix):
    """Record currently allocated CUDA memory, rounded to whole MB.

    Args:
        rank: not used by this function; accepted so call sites can pass it.
        result: dict that is updated in place.
        prefix: key under which the measurement is stored in ``result``.
    """
    allocated_bytes = torch.cuda.memory_allocated()
    # Dividing by 1024 * 1024 (a power of two) is exact, same as two /1024 steps.
    result[prefix] = round(allocated_bytes / (1024 * 1024))
class Model(nn.Module):
    """Small conv net (stem -> conv blocks -> linear head) used by the memory test.

    Args:
        hidden_dim: channel width of the conv blocks and input width of the
            10-way linear head.
        with_fsdp: if true, every ``BatchNorm2d`` is wrapped in its own FSDP
            instance (nested wrapping); otherwise plain ``BatchNorm2d`` is used.
        with_checkpoint: if true, ``self.blocks`` runs under activation
            checkpointing in ``forward``.
    """

    def __init__(self, hidden_dim, with_fsdp=False, with_checkpoint=False):
        super().__init__()

        def _bn(num_features):
            # Single place that decides whether a norm layer is FSDP-wrapped, so
            # the FSDP and non-FSDP layouts cannot drift apart.
            layer = nn.BatchNorm2d(num_features)
            return FSDP(layer) if with_fsdp else layer

        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            _bn(64),
            nn.ReLU(inplace=True),
        )
        self.blocks = nn.Sequential(
            nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
            _bn(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            _bn(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            _bn(hidden_dim),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.head = nn.Linear(hidden_dim, 10)
        self.with_checkpoint = with_checkpoint

    def forward(self, x):
        features = self.stem(x)
        if self.with_checkpoint:
            # Pass use_reentrant explicitly: newer PyTorch warns when it is
            # omitted. True keeps the reentrant implementation that was the
            # implicit default, so the measured activation memory is unchanged.
            features = checkpoint(self.blocks, features, use_reentrant=True)
        else:
            features = self.blocks(features)
        return self.head(features)
def create_model(with_fsdp, with_checkpoint, model_hidden_dim):
    """Construct a ``Model`` after seeding the RNG so initial weights are reproducible.

    When ``with_fsdp`` is true, the top-level ``stem``, ``blocks`` and ``head``
    submodules are each additionally wrapped in their own FSDP instance, in
    that order.
    """
    torch.manual_seed(0)
    net = Model(model_hidden_dim, with_fsdp, with_checkpoint)
    if not with_fsdp:
        return net
    for child_name in ("stem", "blocks", "head"):
        setattr(net, child_name, FSDP(getattr(net, child_name)))
    return net
class TestFSDPMemory(FSDPTest):
    """Checks per-iteration CUDA memory usage of an FSDP-wrapped model.

    Each rank records ``torch.cuda.memory_allocated()`` (in MB, via
    ``get_cur_mem``) at fixed points of a few training iterations and compares
    the readings with values derived from the sharded model size.
    """

    @property
    def world_size(self):
        return 2

    def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
        """Train a fully FSDP-wrapped model and compare memory snapshots.

        Args:
            with_checkpoint: whether ``Model.blocks`` uses activation checkpointing.
            expected: mapping from snapshot label (e.g. ``"iter 0: after fwd"``)
                to the expected allocated memory in MB.
            model_hidden_dim: hidden width forwarded to ``create_model``.
            iterations: number of training iterations to run.

        Fails the test if the recorded labels differ from ``expected``'s keys,
        or if any reading differs from its expectation by more than 1 MB.
        """
        gpu_id = self.rank
        batch = torch.randn(size=(2, 3, 224, 224)).cuda()
        model = create_model(
            with_fsdp=True,
            with_checkpoint=with_checkpoint,
            model_hidden_dim=model_hidden_dim,
        )
        model = model.cuda()
        model = FSDP(model)
        # We enable momentum so that after the first iteration, the optimizer
        # state is added to the total memory used.
        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
        results = {}  # memory snapshots in MB, keyed by label
        for iteration in range(iterations):
            get_cur_mem(gpu_id, results, f"iter {iteration}: start")
            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")
            # Reduce the first sample's outputs to a scalar to drive backward.
            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")
            fake_loss.backward()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")
            optimizer.step()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after step")
            # It is important to use `set_to_none` below, not
            # optimizer.zero_grad(), to reclaim the gradient memory.
            model.zero_grad(set_to_none=True)
            get_cur_mem(gpu_id, results, f"iter {iteration}: done")

        self.assertEqual(results.keys(), expected.keys())
        mismatches = "".join(
            f"{k}: got {v}, expected {expected[k]}\n"
            for k, v in results.items()
            if abs(expected[k] - v) > 1  # allow 1MB rounding differences
        )
        self.assertEqual(mismatches, "")

    @skip_if_lt_x_gpu(2)
    @parametrize("ckpt", ["no_ckpt", "ckpt"])
    def test_fsdp_memory(self, ckpt):
        """Compare FSDP training memory against the sharded model size.

        ``ckpt`` is ``"ckpt"`` to enable activation checkpointing or
        ``"no_ckpt"`` to disable it.
        """
        # hidden_dim 128: model size ~4MB
        model_hidden_dim = 128
        # Measure the unsharded model's footprint on this device.
        model = create_model(
            with_fsdp=False, with_checkpoint=False, model_hidden_dim=model_hidden_dim
        ).cuda()
        model_size_mb = round(torch.cuda.memory_allocated() / 1024 / 1024)
        del model
        sharded_model_size_mb = int(model_size_mb / self.world_size)

        # We have observed that sometimes after the 3rd iteration, the 4th one
        # can fail (not on this test but on much bigger scale tests). We run 4
        # iterations here just in case it happens.
        iterations = 4
        with_ckpt = ckpt == "ckpt"
        # Activation memory after forward is hard to calculate; these values
        # were taken from printed memory usage.
        activation_mb = 51 if with_ckpt else 340

        expected = {}
        for iteration in range(iterations):
            # After the optimizer step in the first iteration, memory usage grows
            # by sharded_model_size_mb because of the optimizer (momentum) state,
            # so every pre-step snapshot of later iterations carries that extra.
            optim_state_mb = 0 if iteration == 0 else sharded_model_size_mb
            # sharded model size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: start"] = (
                sharded_model_size_mb + optim_state_mb + 1
            )
            expected[f"iter {iteration}: after fwd"] = activation_mb + optim_state_mb
            expected[f"iter {iteration}: after loss"] = activation_mb + optim_state_mb
            # sharded model size + sharded grad size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: after bwd"] = (
                2 * sharded_model_size_mb + optim_state_mb + 1
            )
            # sharded model size + sharded grad size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1
            # grad memory is reclaimed after setting grad = None:
            # sharded model size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1

        self._dist_train(
            with_ckpt,
            expected,
            model_hidden_dim,
            iterations,
        )
# Materialize the concrete test methods for each @parametrize value of
# TestFSDPMemory.test_fsdp_memory.
instantiate_parametrized_tests(TestFSDPMemory)

if __name__ == "__main__":
    # Running this file directly hands control to the PyTorch test runner.
    run_tests()
|