import functools
import torch
import torch.distributed as dist
import torch.nn as nn
class PythonDDP(nn.Module):
"""
Python only implementation for DistributedDataParallel module.
Unlike the production DistributedDataParallel which relies on many C++ core
utils to manage gradient distribution and reduction. This class implement
all functions in pure Python such as param bucketing, gradient
synchronization and reduction. The only C++ dependency is the common utils:
``dist.all_reduce``
The idea: parallelize gradient calculation and reduction, the same algo as
https://pytorch.org/docs/stable/notes/ddp.html, main steps:
1. Distribute params into list of buckets.
2. Register per-param hook to be invoked when grad is ready during backward
3. In the hook, copy grad to corresponding bucket. If bucket is full, kick
off an async all_reduce operation to calculate average grad.
4. After backward wait for all async ops to be done. Copy reduced grads back
to original places.
Two modes are supported, asynchronous reduction (async_reduction=True) and
synchronous reduction (async_reduction=False) which shares the same algo as
LegacyDistributedDataParallel.
Same as DistributedDataParallel to use this class , a process group needs to
be initiated.
Example::
>>> torch.distributed.init_process_group(
>>> backend='gloo', world_size=N, init_method='...'
>>> )
>>> pg = dist.distributed_c10d._get_default_group()
>>> async_reduction = True
>>> module = ToyModel()
>>> ddp_model = PythonDDP(module, pg, async_reduction)
>>> loss_fn = nn.MSELoss()
>>> optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
>>> outputs = ddp_model(torch.randn(20, 10).to(rank))
>>> labels = torch.randn(20, 10).to(rank)
>>> loss_fn(outputs, labels).backward()
>>>
>>> # Reduce param grads
>>> ddp_model.all_reduce_grads()
>>> optimizer.step()
>>>
"""
class Bucket:
"""Bucket is a container for list of params. """
def __init__(self, max_buffer_size):
self.param_to_offset = {}
self.buffer = None
self.ready_param_grad_count = 0
self.total_elements = 0
self._MAX_BUFFER_SIZE = max_buffer_size
def __str__(self):
return "Bucket: num_params={}, total_elements={}, ready_param_grad_count={}".format(
len(self.param_to_offset),
self.total_elements,
self.ready_param_grad_count)
def is_full(self):
"""
            Returns whether the grads for all params in the current bucket are
            ready and have been copied to ``self.buffer``.
"""
assert self.ready_param_grad_count >= 0
assert self.ready_param_grad_count <= len(self.param_to_offset)
return len(self.param_to_offset) == self.ready_param_grad_count
        def empty(self):
            """Resets the ready count so the bucket can be reused."""
            self.ready_param_grad_count = 0
def try_hold_param(self, param):
"""
            Checks whether the current bucket has enough buffer space to hold
            the incoming param. If there is enough space, assigns the param an
            offset in the current bucket and returns True. Otherwise, returns
            False.
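
            Example (an illustrative sketch; the tiny 4-element buffer is an
            arbitrary choice)::

                >>> bucket = PythonDDP.Bucket(max_buffer_size=4)
                >>> bucket.try_hold_param(nn.Parameter(torch.zeros(3)))
                True
                >>> bucket.try_hold_param(nn.Parameter(torch.zeros(2)))
                False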
"""
            if self.total_elements + param.numel() <= self._MAX_BUFFER_SIZE:
self.param_to_offset[param] = self.total_elements
self.total_elements += param.numel()
return True
else:
return False
def __init__(self, module, process_group, async_reduction=True, buffer_size=2 ** 22):
super(PythonDDP, self).__init__()
self.module = module
self.process_group = process_group
self.world_size = dist.get_world_size(group=self.process_group)
self.async_reduction = async_reduction
# Holds all_reduce handles, used when async_reduction is True
self.async_handles = set()
# Ensure buffer_size is large enough to hold largest param.
max_numel = max(p.numel() for p in module.parameters())
        assert buffer_size > max_numel, (
            "buffer_size: {} should be larger than the largest param: {}".format(
                buffer_size, max_numel))
# Build buckets for params
self.param_to_bucket, self.buckets = self._build_buckets_for_params(buffer_size)
# Register per-parameter hook to be invoked when grad is ready.
for p in self.module.parameters():
assert p.requires_grad
p.register_hook(functools.partial(self._on_param_grad_ready, p))
def _build_buckets_for_params(self, max_buffer_size):
"""
        Distributes params into a list of buckets and maintains a
        param -> bucket mapping. Returns a tuple of (param_to_buckets, buckets).
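
        For example, params with ``numel()`` of 5, 5, and 8 and a
        ``max_buffer_size`` of 10 would yield two buckets: one holding the two
        5-element params, the other holding the 8-element param.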
"""
print("_build_buckets_for_params called")
params_to_buckets = {}
buckets = set()
cur_bucket = self.Bucket(max_buffer_size)
total_param = 0
for param in self.module.parameters():
total_param += 1
assert param.requires_grad, "param.requires_grad must be True"
if cur_bucket.try_hold_param(param):
params_to_buckets[param] = cur_bucket
buckets.add(cur_bucket)
else:
new_bucket = self.Bucket(max_buffer_size)
                assert new_bucket.try_hold_param(param), (
                    "an empty bucket must be able to hold the param")
params_to_buckets[param] = new_bucket
buckets.add(new_bucket)
cur_bucket = new_bucket
first_param = next(self.module.parameters())
for bucket in buckets:
            # Allocate a flat buffer with the same dtype and device as the params.
            bucket.buffer = first_param.new_empty(bucket.total_elements)
assert bucket.buffer is not None, 'bucket.buffer should not be None'
print("len(param_to_bucket)={}, len(buckets)={}".format(
len(params_to_buckets), len(buckets)))
# Sanity check to ensure all params are distributed correctly into buckets
total_params_in_buckets = 0
for bucket in buckets:
total_params_in_buckets += len(bucket.param_to_offset)
assert total_param == total_params_in_buckets
return params_to_buckets, buckets
    # Callback invoked when a param's grad is ready. Note that at callback
    # time ``param.grad`` is not set yet; we MUST use the ``grad`` argument
    # passed to the callback.
def _on_param_grad_ready(self, param, grad):
"""
        Callback invoked when the grad for ``param`` is ready. Copies the grad
        into its corresponding bucket. When the bucket is full and
        ``async_reduction`` is set, kicks off an async all_reduce and adds the
        resulting handle to ``self.async_handles``.

        .. warning::
            ``param.grad`` isn't set yet. Use the passed ``grad`` instead.
        """
# Validate bucket and offset are set.
bucket = self.param_to_bucket.get(param)
assert bucket is not None, "Failed to find bucket for param"
offset = bucket.param_to_offset.get(param)
assert offset is not None, "offset must be set for param"
assert bucket.buffer is not None, "buffer must be allocated"
# Copy grad to bucket, note param.grad isn't ready yet.
sz = param.numel()
assert grad is not None
assert param.requires_grad
assert param.numel() == grad.numel()
bucket.buffer[offset : offset + sz].copy_(grad.detach().view(-1))
bucket.ready_param_grad_count += 1
        # Kick off async grad reduction when the bucket is full. This lets
        # grad reduction overlap with the remaining grad computation.
if self.async_reduction and bucket.is_full():
bucket.buffer.div_(self.world_size)
            handle = dist.all_reduce(
                bucket.buffer, dist.ReduceOp.SUM, self.process_group,
                async_op=True)
self.async_handles.add(handle)
def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)
def all_reduce_grads(self):
"""
        Reduces gradients across all workers and updates param gradients.
        Clients should call this function after backward.

        If ``async_reduction`` is True, waits on all outstanding async
        all_reduce handles; otherwise, kicks off a synchronous all_reduce for
        every bucket. Once all buckets are reduced, copies the reduced grads
        back to their original params. After that, resets all buckets in
        preparation for the next iteration.
"""
if self.async_reduction:
for handle in self.async_handles:
handle.wait()
self.async_handles.clear()
else:
for bucket in self.buckets:
assert bucket.is_full()
bucket.buffer.div_(self.world_size)
dist.all_reduce(bucket.buffer, dist.ReduceOp.SUM, self.process_group)
        # Copy the reduced grads back to their original params.
for bucket in self.buckets:
assert bucket.is_full()
for cur_p, cur_offset in bucket.param_to_offset.items():
sz = cur_p.numel()
if cur_p.grad is not None:
with torch.no_grad():
cur_p.grad.copy_(bucket.buffer[cur_offset : cur_offset + sz].view_as(cur_p))
else:
cur_p.grad = bucket.buffer[cur_offset : cur_offset + sz].view_as(cur_p).clone()
        # Reset buckets for the next iteration.
for bucket in self.buckets:
bucket.empty()
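

# The block below is a minimal single-process smoke test, not part of the
# module proper. It is a sketch that assumes the "gloo" backend is available
# and that port 29500 on localhost is free; the two-layer model is an
# arbitrary stand-in for the ToyModel referenced in the class docstring.
if __name__ == "__main__":
    import os

    import torch.optim as optim

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    pg = dist.distributed_c10d._get_default_group()

    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
    ddp_model = PythonDDP(model, pg, async_reduction=True)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()

    for _ in range(2):
        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        loss = loss_fn(outputs, torch.randn(20, 10))
        loss.backward()
        # Reduce grads across workers before the optimizer step.
        ddp_model.all_reduce_grads()
        optimizer.step()

    dist.destroy_process_group()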