# Owner(s): ["oncall: distributed"]
import sys
import time
from statistics import mean
from unittest.mock import patch
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.cuda import Event
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
FSDPTest,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
get_cycles_per_ms,
run_tests,
)
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
class Layer(nn.Module):
def __init__(self, compute_cycles, has_params: bool):
super().__init__()
self.sleep_cycles = compute_cycles
self.optional_param = None
if has_params:
self.optional_param = nn.Parameter(torch.rand(1))
def forward(self, x):
# Get 2 events.
self.e1 = Event(enable_timing=True)
self.e2 = Event(enable_timing=True)
# Record the fake forward compute time.
self.e1.record()
if self.sleep_cycles > 0:
torch.cuda._sleep(self.sleep_cycles)
if self.optional_param is not None:
x = x + self.optional_param # force the param to be part of the graph
self.e2.record()
return x
def get_time(self):
# return the recorded duration.
return self.e1.elapsed_time(self.e2)
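# Note: Layer.get_time() returns milliseconds and is only meaningful once both
# CUDA events have completed on the GPU; the test below synchronizes (via
# out.item()) before reading the per-layer times.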
def _create_model(compute_cycles, has_params: bool):
model = FSDP(
nn.Sequential(
FSDP(Layer(compute_cycles, has_params)),
FSDP(Layer(compute_cycles, has_params)),
FSDP(Layer(compute_cycles, has_params)),
FSDP(Layer(compute_cycles, has_params)),
)
).cuda()
return model
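# Note: each Layer above is wrapped in its own nested FSDP instance, so its
# parameters (when present) are all-gathered separately during the forward
# pass; the outer FSDP wrapper manages no parameters of its own.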
class Min10:
    """Track the 10 smallest values added so far."""

    def __init__(self):
        self.data = []
def add(self, new_data):
if len(self.data) < 10:
self.data.append(new_data)
else:
self.data = sorted(self.data)
if new_data < self.data[-1]:
self.data[-1] = new_data
def avg(self):
return mean(self.data)
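# A minimal usage sketch of Min10 (hypothetical values):
#   m = Min10()
#   for t in [5.0, 1.0, 9.0] + [2.0] * 10:
#       m.add(t)
#   m.avg()  # averages the 10 smallest samples (1.0 and nine 2.0s) -> 1.9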
class TestForwardOverlapWorldSizeOne(FSDPTest):
@property
def world_size(self):
return 1
def _dist_train(self):
rank = self.rank
world_size = self.world_size
# Save the original torch.distributed._all_gather_base function since we will
# patch it to include an artificial delay.
orig_all_gather = torch.distributed._all_gather_base
def run(compute_cycles, all_gather_cycles):
has_params = all_gather_cycles > 0
model = _create_model(compute_cycles, has_params)
            # Create the input and set requires_grad=True so that the fake
            # compute in the forward pass participates in the autograd graph.
batch = torch.rand(1).cuda()
batch.requires_grad = True
# Run one dummy iteration to trigger the execution order validation
# all-gathers
out = model(batch)
out.backward()
model.zero_grad(set_to_none=True)
            # Run 20 iterations but keep timing data only from the 10 fastest
            # ones, since nondeterministic system events can disturb the timing.
cpu_iter = Min10()
cpu_wait = Min10()
gpu_compute = Min10()
gpu_total = Min10()
for _ in range(20):
# Get two events for measuring the overall time.
e1 = Event(enable_timing=True)
e2 = Event(enable_timing=True)
cpu_start = time.process_time()
all_gather_called = False
def _delayed_all_gather(*args, **kwargs):
nonlocal all_gather_called
all_gather_called = True
torch.cuda._sleep(all_gather_cycles)
assert orig_all_gather
return orig_all_gather(*args, **kwargs)
# forward pass
#
# Even though both e1 & e2 are on the compute stream, since
# compute depends on all_gather, e2-e1 includes all_gather time.
e1.record()
with patch("torch.distributed._all_gather_base", _delayed_all_gather):
out = model(batch)
if has_params and world_size > 1:
self.assertTrue(all_gather_called)
else:
self.assertFalse(all_gather_called)
e2.record()
# backward pass
out.backward()
model.zero_grad(set_to_none=True)
cpu_iter_time = time.process_time() - cpu_start
# wait for gpu
out.item()
cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time
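                # Note: out.item() blocks until the GPU work producing `out`
                # (including the fake compute / all-gather sleeps) has finished,
                # so cpu_wait_for_gpu_time captures how long the CPU waits for
                # the GPU to drain.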
# get sum of the compute time
times = []
for mod in model.modules():
if not isinstance(mod, Layer):
continue
times.append(mod.get_time())
# get gpu compute + all_gather time
overall_gpu_time = e1.elapsed_time(e2)
cpu_iter.add(cpu_iter_time)
cpu_wait.add(cpu_wait_for_gpu_time)
gpu_compute.add(sum(times))
gpu_total.add(overall_gpu_time)
del model
return {
"cpu_iter": cpu_iter.avg(),
"cpu_wait": cpu_wait.avg(),
"gpu_compute": gpu_compute.avg(),
"gpu_total": gpu_total.avg(),
}
sleep_cycles = int(100 * get_cycles_per_ms())
e1 = run(0, 0) # no compute, no all-gather
e2 = run(0, sleep_cycles) # no compute, only all-gather
e3 = run(sleep_cycles, 0) # only compute, no all-gather
e4 = run(sleep_cycles, sleep_cycles) # both compute and all-gather
debug_string = f"\nrank{rank}:\n e1: {e1}\n e2: {e2}\n e3: {e3}\n e4: {e4}"
print(debug_string)
        # Check the CPU/GPU timing. The CPU should run ahead of the GPU, so the
        # CPU-to-GPU wait should be long, except when there is no real work on
        # the GPU.
        #
        # If the assertions below fail, there is likely an unintended CPU-GPU
        # sync in the forward/backward pass. Note that e4["cpu_iter"] may not be
        # short, since the CPU may take some time to queue both the compute and
        # the all-gather.
short = [
e1["cpu_iter"],
e2["cpu_iter"],
e3["cpu_iter"],
e1["cpu_wait"],
]
long = [e3["cpu_wait"], e4["cpu_wait"]]
if world_size == 1:
short.append(e2["cpu_wait"]) # all gather should not be happening.
else:
long.append(
e2["cpu_wait"]
) # all gather should happen and prolong the cpu-gpu wait.
for s in short:
for l in long:
                # A 10x margin is safe, since the GPU work takes roughly 100x
                # longer than the corresponding CPU time.
self.assertTrue(s * 10 < l)
# Check the GPU timing.
short = [e1["gpu_compute"], e1["gpu_total"], e2["gpu_compute"]]
long = [e3["gpu_compute"], e3["gpu_total"], e4["gpu_compute"], e4["gpu_total"]]
if world_size == 1:
short.append(e2["gpu_total"]) # all gather should not be happening.
else:
            long.append(
                e2["gpu_total"]
            )  # The all-gather should happen and prolong the GPU total time.
for s in short:
for l in long:
                # A 10x margin is safe, since the timing is roughly 100x longer
                # when there is GPU work than when there is none.
self.assertTrue(s * 10 < l)
# Check the GPU overlapping when there is all-gather.
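        # As a sanity check on the arithmetic (hypothetical round numbers): with
        # ~400 ms of compute-only time (e3) and ~400 ms of all-gather-only time
        # (e2), the assertion below requires the combined run (e4) to stay well
        # under 800 ms, i.e. roughly 10% or more of the work must overlap.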
if world_size > 1:
compute_only = e3["gpu_compute"]
all_gather_only = e2["gpu_total"]
both = e4["gpu_total"]
self.assertTrue(compute_only + all_gather_only > 1.1 * both)
@skip_if_lt_x_gpu(2)
def test_forward_overlap(self):
self._dist_train()
class TestForwardOverlapWorldSizeTwo(TestForwardOverlapWorldSizeOne):
@property
def world_size(self):
return 2
if __name__ == "__main__":
run_tests()