<!--
Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# DDP Communication Hooks

Distributed Data Parallel (DDP) communication hooks provide a generic interface to control how gradients are communicated across workers by overriding the vanilla allreduce in `DistributedDataParallel`. PyTorch provides several built-in communication hooks, and any of them can be applied to reduce communication cost:


- **FP16 Compression Hook**: Compresses gradients by casting them to half-precision floating-point format (`torch.float16`), reducing communication overhead.
- **BF16 Compression Hook**: Similar to FP16, but uses the Brain Floating Point format (`torch.bfloat16`), which keeps the exponent range of FP32 and can be more efficient on hardware with native bfloat16 support.
- **PowerSGD Hook**: An advanced gradient compression algorithm that provides high compression rates and can accelerate bandwidth-bound distributed training.

In this tutorial, you will see how to set up DDP communication hooks and train with the utilities provided in Accelerate, which can be as simple as adding one new line of code.
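
To make the trade-off between the FP16 and BF16 hooks concrete, the following standard-library sketch rounds a few hypothetical gradient values to each format and compares payload sizes. It is only an illustration and not how PyTorch implements the hooks: the `to_fp16` and `to_bf16` helpers are made up for this example, and `to_bf16` truncates bits, whereas real conversions typically round to nearest.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round-trip x through bfloat16 by keeping the top 16 bits of its float32 encoding.

    Simplification: this truncates instead of rounding to nearest.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


grads = [0.1, -3.14159, 0.001, 1234.5]  # hypothetical gradient values

# Both 16-bit formats halve the number of bytes sent per gradient element
fp32_bytes = len(struct.pack(f"<{len(grads)}f", *grads))
fp16_bytes = len(struct.pack(f"<{len(grads)}e", *grads))
print(f"FP32 payload: {fp32_bytes} bytes, FP16 payload: {fp16_bytes} bytes")

# FP16 has more mantissa bits, so it is more precise for values inside its range
for g in grads:
    print(f"{g:>10}: fp16={to_fp16(g)!r}, bf16={to_bf16(g)!r}")

# BF16 keeps the FP32 exponent range, so a value above FP16's largest
# finite number (65504) is still representable
print(to_bf16(70000.0))
```

In short, FP16 trades range for precision and BF16 trades precision for range, while both send half as many bytes as FP32.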

## FP16 Compression Hook

<hfoptions id="fp16">
<hfoption id="PyTorch">

```python
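import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
from torch.utils.data import DataLoader, TensorDataset
from accelerate.test_utils.testing import get_backend

device_type, _, _ = get_backend()
device_id = getattr(torch, device_type, torch.cuda).current_device()

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Assumes the default process group is already initialized, e.g. with
# torch.distributed.init_process_group() in a script launched by torchrun
model = MyModel().to(device_type)
model = DDP(model, device_ids=[device_id])
model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
# Placeholder synthetic data; replace with your own dataset
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 10))
data_loader = DataLoader(dataset, batch_size=16)

# Training loop
for data, targets in data_loader:
    data, targets = data.to(device_type), targets.to(device_type)
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()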
```

</hfoption>
<hfoption id="Accelerate">

```python
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
import torch
from torch.utils.data import DataLoader, TensorDataset

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.FP16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.MSELoss()
# Placeholder synthetic data; replace with your own dataset
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 10))
data_loader = DataLoader(dataset, batch_size=16)

model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```

</hfoption>
</hfoptions>

## BF16 Compression Hook

<Tip warning={true}>

The BF16 compression hook API is experimental and requires an NCCL version later than 2.9.6.

</Tip>

<hfoptions id="bf16">
<hfoption id="PyTorch">

```python
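import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
from torch.utils.data import DataLoader, TensorDataset
from accelerate.test_utils.testing import get_backend

device_type, _, _ = get_backend()
device_id = getattr(torch, device_type, torch.cuda).current_device()

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Assumes the default process group is already initialized, e.g. with
# torch.distributed.init_process_group() in a script launched by torchrun
model = MyModel().to(device_type)
model = DDP(model, device_ids=[device_id])
model.register_comm_hook(state=None, hook=default_hooks.bf16_compress_hook)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
# Placeholder synthetic data; replace with your own dataset
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 10))
data_loader = DataLoader(dataset, batch_size=16)

# Training loop
for data, targets in data_loader:
    data, targets = data.to(device_type), targets.to(device_type)
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()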
```

</hfoption>
<hfoption id="Accelerate">

```python
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
import torch
from torch.utils.data import DataLoader, TensorDataset

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.BF16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.MSELoss()
# Placeholder synthetic data; replace with your own dataset
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 10))
data_loader = DataLoader(dataset, batch_size=16)

model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```

</hfoption>
</hfoptions>

## PowerSGD Hook

<Tip warning={true}>

PowerSGD typically requires extra memory of the same size as the model’s gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy.

</Tip>
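
The idea behind error feedback can be sketched without any distributed setup. The example below is a simplified illustration and not PowerSGD itself: instead of low-rank approximation it uses a made-up biased compressor (`truncate_to_bf16`, which drops the low 16 bits of a float32), and it tracks a single scalar gradient. What it shows is the general pattern of storing what compression dropped and adding it back on the next step, which is why a buffer the size of the gradients is needed.

```python
import struct


def truncate_to_bf16(x: float) -> float:
    """Biased stand-in compressor: keep only the top 16 bits of the float32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def run(steps: int, grad: float, use_error_feedback: bool) -> float:
    """Return the sum of the values that would be communicated over `steps` steps."""
    error = 0.0  # for a real model this buffer has the same size as the gradients
    total_sent = 0.0
    for _ in range(steps):
        # Add back what compression dropped on the previous step
        corrected = grad + error if use_error_feedback else grad
        sent = truncate_to_bf16(corrected)
        if use_error_feedback:
            error = corrected - sent  # remember what was lost this time
        total_sent += sent
    return total_sent


steps, grad = 100, 0.1  # hypothetical values
exact = steps * grad
plain = run(steps, grad, use_error_feedback=False)
with_feedback = run(steps, grad, use_error_feedback=True)

print(f"drift without error feedback: {abs(plain - exact):.6f}")
print(f"drift with error feedback:    {abs(with_feedback - exact):.6f}")
```

Without error feedback the compression bias accumulates at every step; with it, the total drift stays bounded by the error of a single compression step.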

<hfoptions id="powerSGD">
<hfoption id="PyTorch">

```python
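import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook
from torch.utils.data import DataLoader, TensorDataset
from accelerate.test_utils.testing import get_backend

device_type, _, _ = get_backend()
device_id = getattr(torch, device_type, torch.cuda).current_device()

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Assumes the default process group is already initialized, e.g. with
# torch.distributed.init_process_group() in a script launched by torchrun
model = MyModel().to(device_type)
model = DDP(model, device_ids=[device_id])
# process_group=None uses the default process group
state = powerSGD_hook.PowerSGDState(process_group=None)
model.register_comm_hook(state=state, hook=powerSGD_hook.powerSGD_hook)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
# Placeholder synthetic data; replace with your own dataset
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 10))
data_loader = DataLoader(dataset, batch_size=16)

# Training loop
for data, targets in data_loader:
    data, targets = data.to(device_type), targets.to(device_type)
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()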
```

</hfoption>
<hfoption id="Accelerate">

```python
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
import torch
from torch.utils.data import DataLoader, TensorDataset

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.POWER_SGD)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.MSELoss()
# Placeholder synthetic data; replace with your own dataset
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 10))
data_loader = DataLoader(dataset, batch_size=16)

model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```

</hfoption>
</hfoptions>
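
To get a rough feel for the compression rates PowerSGD can reach, the sketch below counts communicated elements under the assumption that a `rows x cols` gradient matrix is replaced by two low-rank factors of shapes `rows x rank` and `cols x rank`. The helper names are hypothetical, and the count ignores details such as tensors that are left uncompressed, so treat the numbers as estimates only.

```python
def powersgd_elements(rows: int, cols: int, rank: int) -> int:
    """Elements sent for one matrix: one (rows x rank) factor plus one (cols x rank) factor."""
    return (rows + cols) * rank


def compression_ratio(rows: int, cols: int, rank: int) -> float:
    """Ratio of uncompressed elements to elements sent with low-rank factors."""
    return (rows * cols) / powersgd_elements(rows, cols, rank)


# Hypothetical layer shapes, from the toy Linear(10, 10) to larger matrices
for rows, cols in [(10, 10), (1024, 1024), (4096, 1024)]:
    for rank in (1, 2, 4):
        ratio = compression_ratio(rows, cols, rank)
        print(f"{rows}x{cols}, rank {rank}: about {ratio:.1f}x fewer elements")
```

The ratio grows with the matrix size and shrinks as the rank increases, which is why the rank is the main knob for trading accuracy against bandwidth.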

## DDP Communication Hooks utilities

Two additional options, `comm_wrapper` and `comm_state_option`, extend the communication hooks with optional functionality.

### comm_wrapper

`comm_wrapper` is an option to wrap a communication hook with additional functionality. For example, it can be used to combine FP16 compression with other communication strategies. Currently supported wrappers are `no`, `fp16`, and `bf16`.

```python
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
import torch
from torch.utils.data import DataLoader, TensorDataset

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_wrapper=DDPCommunicationHookType.FP16
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.MSELoss()
# Placeholder synthetic data; replace with your own dataset
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 10))
data_loader = DataLoader(dataset, batch_size=16)

model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```
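
Conceptually, a wrapper is function composition: cast the gradients down, call the inner hook, then cast the result back up. The following standard-library sketch mimics that pattern on plain Python lists. It is an analogy under simplifying assumptions and not the PyTorch implementation: `halve_hook`, `fp16_wrapper`, and `to_fp16` are hypothetical names, and the inner hook does no real communication.

```python
import struct
from typing import Callable, List

Hook = Callable[[List[float]], List[float]]


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def halve_hook(grads: List[float]) -> List[float]:
    """Hypothetical inner hook; a real one would communicate across workers."""
    return [g / 2 for g in grads]


def fp16_wrapper(hook: Hook) -> Hook:
    """Return a new hook that rounds gradients to FP16 before calling `hook`."""

    def wrapped(grads: List[float]) -> List[float]:
        compressed = [to_fp16(g) for g in grads]  # cast down before the inner hook
        reduced = hook(compressed)  # the inner hook only sees FP16-rounded values
        return [float(r) for r in reduced]  # cast back up for the optimizer

    return wrapped


wrapped_hook = fp16_wrapper(halve_hook)
print(halve_hook([0.1]))    # inner hook on full-precision input
print(wrapped_hook([0.1]))  # same hook, but applied to the FP16-rounded input
```

Because the wrapper only changes what the inner hook receives and returns, it can in principle be combined with any hook, which is what `comm_wrapper` exposes as an option.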

### comm_state_option

`comm_state_option` allows you to pass additional state information required by certain communication hooks. This is particularly useful for stateful hooks like `PowerSGD`, which require maintaining hyperparameters and internal states across training steps. Below is an example showcasing the use of `comm_state_option` with the `PowerSGD` hook.

```python
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
import torch
from torch.utils.data import DataLoader, TensorDataset

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_state_option={"matrix_approximation_rank": 2}
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.MSELoss()
# Placeholder synthetic data; replace with your own dataset
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 10))
data_loader = DataLoader(dataset, batch_size=16)

model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```

For more advanced usage and additional hooks, refer to the [PyTorch DDP Communication Hooks documentation](https://pytorch.org/docs/stable/ddp_comm_hooks.html).