File: gradient_accumulation.md

<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Performing gradient accumulation with Accelerate

Gradient accumulation is a technique that lets you train with larger effective batch sizes than
your machine can normally fit into memory. It works by accumulating gradients over
several batches and stepping the optimizer only after a set number of batches has been processed.
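To make the idea concrete, here is a minimal pure-Python sketch (no PyTorch needed). It assumes a toy model `prediction = w * x` with a mean squared error loss, and the helper `mse_grad` is made up for this illustration. It shows that summing the gradients of equally sized micro-batches, each scaled by `1 / accumulation_steps`, gives the same gradient as one pass over the full batch:

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]
w = 0.0
accumulation_steps = 4
micro_batch_size = len(xs) // accumulation_steps

# Gradient computed on the full batch in a single pass.
full_grad = mse_grad(w, xs, ys)

# Gradient accumulated over micro-batches, each one scaled by 1 / accumulation_steps.
accumulated_grad = 0.0
for start in range(0, len(xs), micro_batch_size):
    micro_x = xs[start:start + micro_batch_size]
    micro_y = ys[start:start + micro_batch_size]
    accumulated_grad += mse_grad(w, micro_x, micro_y) / accumulation_steps

print(full_grad, accumulated_grad)  # -102.0 -102.0
```

The equality relies on every micro-batch having the same number of samples, so that the average of the per-batch means equals the overall mean.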

While standard gradient accumulation code technically works in a distributed setup, it is not the most efficient
way to do it, and you may experience considerable slowdowns!

In this tutorial you will see how to quickly set up gradient accumulation and perform it with the utilities provided in Accelerate,
which can amount to adding just one new line of code!

This example will use a very simplistic PyTorch training loop that performs gradient accumulation every two batches:

```python
device = "cuda"
model.to(device)

gradient_accumulation_steps = 2

for index, batch in enumerate(training_dataloader):
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    loss = loss / gradient_accumulation_steps
    loss.backward()
    if (index + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```

## Converting it to Accelerate

First the code shown earlier will be converted to utilize Accelerate without the special gradient accumulation helper:

```diff
+ from accelerate import Accelerator
+ accelerator = Accelerator()

+ model, optimizer, training_dataloader, scheduler = accelerator.prepare(
+     model, optimizer, training_dataloader, scheduler
+ )

  for index, batch in enumerate(training_dataloader):
      inputs, targets = batch
-     inputs = inputs.to(device)
-     targets = targets.to(device)
      outputs = model(inputs)
      loss = loss_function(outputs, targets)
      loss = loss / gradient_accumulation_steps
-     loss.backward()
+     accelerator.backward(loss)
      if (index+1) % gradient_accumulation_steps == 0:
          optimizer.step()
          scheduler.step()
          optimizer.zero_grad()
```

<Tip warning={true}>

  In its current state, this code is not going to perform gradient accumulation efficiently due to a process called gradient synchronization. Read more about that in the [Concepts tutorial](../concept_guides/gradient_synchronization)!

</Tip>
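To get an intuition for why skipping synchronization during accumulation is safe, consider the following sketch. The gradient values are made up, and plain Python lists stand in for the processes. It compares averaging gradients across processes after every backward pass with accumulating locally and averaging only once:

```python
# Hypothetical gradients of a single parameter: local_grads[process][micro_batch]
local_grads = [
    [0.5, -1.0, 2.0],  # process 0
    [1.5, 3.0, -4.0],  # process 1
]
num_processes = len(local_grads)
num_micro_batches = len(local_grads[0])

# Option A: average across processes after every backward pass.
synced_every_step = 0.0
syncs_a = 0
for step in range(num_micro_batches):
    synced_every_step += sum(grads[step] for grads in local_grads) / num_processes
    syncs_a += 1

# Option B: accumulate locally, then average across processes once.
local_sums = [sum(grads) for grads in local_grads]
synced_once = sum(local_sums) / num_processes
syncs_b = 1

print(synced_every_step, synced_once)  # 1.0 1.0
print(syncs_a, syncs_b)  # 3 1
```

Both options produce the same accumulated gradient, but option B communicates only once per optimizer step instead of once per batch.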

## Letting Accelerate handle gradient accumulation

All that is left now is to let Accelerate handle the gradient accumulation for us. To do so, pass a `gradient_accumulation_steps` parameter to [`Accelerator`]. It sets the number of batches to accumulate 
before each optimizer `step()` and tells Accelerate how to automatically adjust the loss during the call to [`~Accelerator.backward`]:

```diff
  from accelerate import Accelerator
- accelerator = Accelerator()
+ accelerator = Accelerator(gradient_accumulation_steps=2)
```

Alternatively, you can pass in a `gradient_accumulation_plugin` parameter to the [`Accelerator`] object's `__init__`, which will allow you to further customize the gradient accumulation behavior. 
Read more about that in the [GradientAccumulationPlugin](../package_reference/accelerator#accelerate.utils.GradientAccumulationPlugin) docs.

From here you can use the [`~Accelerator.accumulate`] context manager from inside your training loop to automatically perform the gradient accumulation for you!
You just wrap it around the entire training part of your code:

```diff
- for index, batch in enumerate(training_dataloader):
+ for batch in training_dataloader:
+     with accelerator.accumulate(model):
          inputs, targets = batch
          outputs = model(inputs)
```

You can remove all the special checks for the step number and the loss adjustment:

```diff
- loss = loss / gradient_accumulation_steps
  accelerator.backward(loss)
- if (index+1) % gradient_accumulation_steps == 0:
  optimizer.step()
  scheduler.step()
  optimizer.zero_grad()
```

As you can see, the [`Accelerator`] keeps track of the batch number you are on, so it automatically knows whether to step through the prepared optimizer and how to adjust the loss.

<Tip>

Typically with gradient accumulation, you would need to adjust the number of steps to reflect the change in total batches you are 
training on. Accelerate automagically does this for you by default. Behind the scenes we instantiate a [`GradientAccumulationPlugin`] configured to do this.

</Tip>

<Tip warning={true}>

The [`state.GradientState`] is synced with the active dataloader being iterated upon. As such, it naively assumes that when the end of the dataloader is reached, everything will sync and a step will be performed. To disable this, set `sync_with_dataloader` to `False` in the [`GradientAccumulationPlugin`]:

```python
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

plugin = GradientAccumulationPlugin(sync_with_dataloader=False)
accelerator = Accelerator(..., gradient_accumulation_plugin=plugin)
```

</Tip>
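As a rough illustration of the behavior described above, the sketch below lists the batch indices at which an optimizer step would happen during a single pass over a dataloader. The helper `optimizer_step_batches` is hypothetical and only mimics the bookkeeping. It is not Accelerate's actual implementation, and it ignores what happens across epochs:

```python
def optimizer_step_batches(num_batches, accumulation_steps, sync_with_dataloader=True):
    """Return the batch indices at which the optimizer would step in one pass."""
    steps = []
    for index in range(num_batches):
        reached_boundary = (index + 1) % accumulation_steps == 0
        end_of_dataloader = index == num_batches - 1
        # With sync_with_dataloader, the last batch always triggers a step.
        if reached_boundary or (sync_with_dataloader and end_of_dataloader):
            steps.append(index)
    return steps


print(optimizer_step_batches(5, 2))  # [1, 3, 4]
print(optimizer_step_batches(5, 2, sync_with_dataloader=False))  # [1, 3]
```

With five batches and two accumulation steps, the default setting steps on the leftover fifth batch, while the second call leaves its gradients accumulated.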

## The finished code

Below is the finished implementation for performing gradient accumulation with Accelerate:

```python
from accelerate import Accelerator
accelerator = Accelerator(gradient_accumulation_steps=2)
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
    model, optimizer, training_dataloader, scheduler
)
for batch in training_dataloader:
    with accelerator.accumulate(model):
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```

<Tip warning={true}>

It's important that **only one forward/backward** pass is done inside the context manager `with accelerator.accumulate(model)`.

</Tip>


To learn more about the magic this wraps around, read the [Gradient Synchronization concept guide](../concept_guides/gradient_synchronization).


## Self-contained example

Here is a self-contained example that you can run to see gradient accumulation in action with Accelerate:

```python
import torch
import copy
from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.utils.data import TensorDataset, DataLoader

# seed
set_seed(0)

# define toy inputs and labels
x = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])
y = torch.tensor([2., 4., 6., 8., 10., 12., 14., 16.])
gradient_accumulation_steps = 4
per_device_batch_size = len(x) // gradient_accumulation_steps

# define dataset and dataloader
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=per_device_batch_size)

# define model, optimizer and loss function
class SimpleLinearModel(torch.nn.Module):
    def __init__(self):
        super(SimpleLinearModel, self).__init__()
        self.weight = torch.nn.Parameter(torch.zeros((1, 1)))

    def forward(self, inputs):
        return inputs @ self.weight

model = SimpleLinearModel()
model_clone = copy.deepcopy(model)
criterion = torch.nn.MSELoss()
model_optimizer = torch.optim.SGD(model.parameters(), lr=0.02)
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, model_optimizer, dataloader = accelerator.prepare(model, model_optimizer, dataloader)
model_clone_optimizer = torch.optim.SGD(model_clone.parameters(), lr=0.02)
print(f"initial model weight is {model.weight.mean().item():.5f}")
print(f"initial model weight is {model_clone.weight.mean().item():.5f}")
for i, (inputs, labels) in enumerate(dataloader):
    with accelerator.accumulate(model):
        inputs = inputs.view(-1, 1)
        print(i, inputs.flatten())
        labels = labels.view(-1, 1)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        accelerator.backward(loss)
        model_optimizer.step()
        model_optimizer.zero_grad()
loss = criterion(x.view(-1, 1) @ model_clone.weight, y.view(-1, 1))
model_clone_optimizer.zero_grad()
loss.backward()
model_clone_optimizer.step()
print(f"w/ accumulation, the final model weight is {model.weight.mean().item():.5f}")
print(f"w/o accumulation, the final model weight is {model_clone.weight.mean().item():.5f}")
```
```
initial model weight is 0.00000
initial model weight is 0.00000
0 tensor([1., 2.])
1 tensor([3., 4.])
2 tensor([5., 6.])
3 tensor([7., 8.])
w/ accumulation, the final model weight is 2.04000
w/o accumulation, the final model weight is 2.04000
```

## Gradient accumulation on training samples of variable size

This [blog post](https://huggingface.co/blog/gradient_accumulation) points out a common error that occurs when performing gradient accumulation on training samples of variable size:

>  [...] for gradient accumulation across token-level tasks like causal LM training, the correct loss should be computed by the **total loss across all batches in a gradient accumulation step** divided by the **total number of all non padding tokens in those batches**. This is not the same as the average of the per-batch loss values. 

In other words, some adjustments must be made on losses that operate on a token-level basis.
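A small worked example, using made-up per-token loss values, shows how far apart the two quantities can be when batches contain different numbers of non-padding tokens:

```python
# Hypothetical per-token losses for two micro-batches of different sizes.
batch_a = [2.0, 2.0]                       # 2 non-padding tokens
batch_b = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]   # 6 non-padding tokens

# Naive approach: average the per-batch mean losses.
mean_a = sum(batch_a) / len(batch_a)
mean_b = sum(batch_b) / len(batch_b)
average_of_means = (mean_a + mean_b) / 2

# Correct approach: total loss divided by the total number of tokens.
token_level_mean = (sum(batch_a) + sum(batch_b)) / (len(batch_a) + len(batch_b))

print(average_of_means)  # 1.25
print(token_level_mean)  # 0.875
```

The naive average gives the two tokens of `batch_a` as much weight as the six tokens of `batch_b`, which is why the results differ.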

### Skeleton code

```python
from accelerate import Accelerator
import math
import contextlib

gradient_accumulation_steps = 2
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
    model, optimizer, training_dataloader, scheduler
)

training_iterator = iter(training_dataloader)
num_samples_in_epoch = len(training_dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)
        

total_batched_samples = 0
for update_step in range(total_updates):
        # In order to correctly count the total number of non-padded tokens on which we'll compute the cross-entropy loss,
        # we need to pre-load the full local batch, i.e. the next per_device_batch_size * accumulation_steps samples
        batch_samples = []
        num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
        for _ in range(num_batches_in_step):
            batch_samples += [next(training_iterator)]
            
        # get local num items in batch 
        num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
        # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch.
        num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()
            
        for i, batch in enumerate(batch_samples):
            # if we perform gradient accumulation in a multi-device setup, we want to avoid unnecessary communications when accumulating
            # cf: https://muellerzr.github.io/blog/gradient_accumulation.html
            if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
                ctx = model.no_sync
            else:
                ctx = contextlib.nullcontext
            
            total_batched_samples += 1

            with ctx():
                inputs, targets = batch
                outputs = model(inputs)
                loss = loss_function(outputs, targets) # the loss function should sum over samples rather than averaging
                
                # We multiply by num_processes because DDP averages the gradients across all devices, whereas dividing by num_items_in_batch already takes all devices into account
                # Same reason for gradient_accumulation_steps, but this time it is Accelerate that averages the gradients across the accumulated steps
                loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
                
                accelerator.backward(loss)

        # Sync gradients and perform optimization steps once every gradient_accumulation_steps
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```
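The scaling factor in the skeleton can be checked with simple arithmetic. The sketch below uses made-up summed losses and a made-up token count. Because gradients are linear in the loss, the loss values stand in for the gradients here. The two divisions mimic what the comments in the skeleton describe: Accelerate dividing by `gradient_accumulation_steps` during `backward`, and DDP averaging across processes:

```python
accumulation_steps = 2
num_processes = 2

# Hypothetical summed losses: summed_losses[process][micro_batch]
summed_losses = [[3.0, 5.0], [4.0, 8.0]]
# Hypothetical total number of non-padding tokens over all processes and micro-batches.
num_items_in_batch = 40

# What we want: total loss divided by the total number of tokens.
target = sum(sum(row) for row in summed_losses) / num_items_in_batch

effective = 0.0
for process_losses in summed_losses:
    for loss in process_losses:
        scaled = loss * accumulation_steps * num_processes / num_items_in_batch
        scaled /= accumulation_steps  # division applied during accelerator.backward
        scaled /= num_processes       # averaging across processes done by DDP
        effective += scaled

print(target, effective)
```

The two extra factors cancel the two averaging steps, so the accumulated value matches the token-level mean.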

### Self-contained causal LM example

```py
import torch
import copy
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.logging import get_logger
from torch.utils.data import Dataset, DataLoader
import math
import contextlib

# seed
set_seed(0)
logger = get_logger(__name__)

class MyDataset(Dataset):
    def __init__(self, num_samples):
        super().__init__()
        self.len = num_samples

    def __getitem__(self, index):
        input_ids = torch.arange(1, index+2, dtype=torch.float32)
        labels = torch.remainder(input_ids, 2)
        return {"input_ids": input_ids, "labels": labels}

    def __len__(self):
        return self.len
    
def collate_fn(features):
    input_ids = torch.nn.utils.rnn.pad_sequence([f["input_ids"] for f in features], batch_first=True, padding_value=-100)
    labels = torch.nn.utils.rnn.pad_sequence([f["labels"] for f in features], batch_first=True, padding_value=-100)
    return {"input_ids": input_ids[..., None], "labels": labels[..., None]}

# define toy inputs and labels
gradient_accumulation_steps = 2
per_device_batch_size = 4

# define accelerator
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

# define dataset and dataloader
# for this toy example, we'll compute gradient descent over one single global batch
dataset = MyDataset(per_device_batch_size*gradient_accumulation_steps*accelerator.num_processes)
dataloader = DataLoader(dataset, batch_size=per_device_batch_size, collate_fn=collate_fn)

# define model, model_optimizer and loss function
model = torch.nn.Linear(1, 2, bias=False)
model_clone = copy.deepcopy(model)
criterion = torch.nn.CrossEntropyLoss(reduction="sum") # must sum over samples rather than averaging
model_optimizer = torch.optim.SGD(model.parameters(), lr=0.08)


logger.warning(f"initial model weight is {model.weight.detach().cpu().squeeze()}")
logger.warning(f"initial model clone weight is {model_clone.weight.detach().cpu().squeeze()}")

# prepare artifacts - accelerator handles device placement and dataloader splitting
model, model_optimizer = accelerator.prepare(model, model_optimizer)
dataloader = accelerator.prepare_data_loader(dataloader, device_placement=True)
training_iterator = iter(dataloader)

num_samples_in_epoch = len(dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)

total_batched_samples = 0
for update_step in range(total_gradient_updates):
        # In order to correctly count the total number of non-padded tokens on which we'll compute the cross-entropy loss,
        # we need to pre-load the full local batch, i.e. the next per_device_batch_size * accumulation_steps samples
        batch_samples = []
        num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder
        for _ in range(num_batches_in_step):
            batch_samples += [next(training_iterator)]
            
        # get local num items in batch 
        local_num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
        logger.warning(f"Step {update_step} - Device {accelerator.process_index} - num items in the local batch {local_num_items_in_batch}", main_process_only=False)

        # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch.
        num_items_in_batch = accelerator.gather(local_num_items_in_batch).sum().item()
        logger.warning(f"Total num items {num_items_in_batch}")

        for i, batch in enumerate(batch_samples):
            inputs, labels = batch["input_ids"], batch["labels"]
            total_batched_samples += 1
            # if we perform gradient accumulation in a multi-device setup, we want to avoid unnecessary communications when accumulating
            # cf: https://muellerzr.github.io/blog/gradient_accumulation.html
            if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
                ctx = model.no_sync
            else:
                ctx = contextlib.nullcontext
            with ctx():

                outputs = model(inputs)
                loss = criterion(outputs.view(-1, 2), labels.view(-1).to(torch.int64))
                
                # We multiply by num_processes because DDP averages the gradients across all devices, whereas dividing by num_items_in_batch already takes all devices into account
                # Same reason for gradient_accumulation_steps, but this time it is Accelerate that averages the gradients across the accumulated steps
                loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
                accelerator.backward(loss)
        model_optimizer.step()
        model_optimizer.zero_grad()
                

logger.warning(f"Device {accelerator.process_index} - w/ accumulation, the final model weight is {accelerator.unwrap_model(model).weight.detach().cpu().squeeze()}", main_process_only=False)

# We now do the same operation but on a single device and without gradient accumulation

if accelerator.is_main_process:
    # prepare one single entire batch
    dataloader = DataLoader(dataset, batch_size=len(dataset), collate_fn=collate_fn)
    full_batch_without_accum = next(iter(dataloader))
    total_inputs, total_labels = full_batch_without_accum["input_ids"], full_batch_without_accum["labels"]
    model_clone_optimizer = torch.optim.SGD(model_clone.parameters(), lr=0.08)
    
    # train the cloned model
    loss = torch.nn.CrossEntropyLoss(reduction="mean")(model_clone(total_inputs).view(-1, 2), total_labels.view(-1).to(torch.int64))
    model_clone_optimizer.zero_grad()
    loss.backward()
    model_clone_optimizer.step()
    
    # We should have the same final weights.
    logger.warning(f"w/o accumulation, the final model weight is {model_clone.weight.detach().cpu().squeeze()}")

```

Results on a single device - gradient accumulation steps set to 1 and batch_size set to 8:
```
initial model weight is tensor([-0.0075,  0.5364])
initial model clone weight is tensor([-0.0075,  0.5364])
Step 0 - Device 0 - num items in the local batch 36
Total num items 36
Device 0 - w/ accumulation, the final model weight is tensor([0.0953, 0.4337])
w/o accumulation, the final model weight is tensor([0.0953, 0.4337])
```

Results on a two-device setup - gradient accumulation steps set to 2 and batch_size set to 4:
```
initial model weight is tensor([-0.0075,  0.5364])
initial model clone weight is tensor([-0.0075,  0.5364])
Step 0 - Device 0 - num items in the local batch 52
Step 0 - Device 1 - num items in the local batch 84
Total num items 136
Device 1 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
Device 0 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
w/o accumulation, the final model weight is tensor([0.2117, 0.3172])
```
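The item counts in the two-device log can be reproduced with a few lines of plain Python. The sketch assumes that the prepared dataloader hands out consecutive batches to the processes in round-robin order, which is consistent with the numbers logged above. It relies on the fact that sample `index` of `MyDataset` contains `index + 1` tokens:

```python
per_device_batch_size = 4
gradient_accumulation_steps = 2
num_processes = 2
num_samples = per_device_batch_size * gradient_accumulation_steps * num_processes

# Sample `index` in MyDataset holds index + 1 non-padding tokens.
lengths = [index + 1 for index in range(num_samples)]
batches = [
    lengths[start:start + per_device_batch_size]
    for start in range(0, num_samples, per_device_batch_size)
]

# Assumption: batch 0 goes to process 0, batch 1 to process 1, batch 2 to process 0, and so on.
local_counts = [
    sum(sum(batch) for batch in batches[rank::num_processes])
    for rank in range(num_processes)
]

print(local_counts, sum(local_counts))  # [52, 84] 136
```

The two devices see very different numbers of tokens (52 versus 84), which is why the local counts must be gathered before normalizing the loss.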

### To go further:

Please find a complete example script on a real world training run in the examples folder at the path [`accelerate/examples/by_feature/gradient_accumulation_for_autoregressive_models.py`](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/gradient_accumulation_for_autoregressive_models.py).

Running it on several training configurations with constant global batch size equal to 32 gives the following graph:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/hf-audio/gradient_accumulation_example/resolve/main/training_losses.png">
</div>

Note that the training losses are exactly the same up to training step 20. The small deviation after this training step occurs at the very end of the first epoch, because, by [default](https://huggingface.co/docs/accelerate/en/package_reference/torch_wrappers#accelerate.data_loader.prepare_data_loader.even_batches), the dataloader duplicates the samples at the beginning of the dataset when the total batch size doesn't exactly divide the dataset.