DDP Communication Hooks
=======================

DDP communication hook is a generic interface to control how gradients are
communicated across workers by overriding the vanilla allreduce in
`DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_.
A few built-in communication hooks are provided,
and users can easily apply any of them to optimize communication.
In addition, the hook interface supports user-defined communication
strategies for more advanced use cases.

How to Use a Communication Hook?
--------------------------------

To use a communication hook, the user simply needs to register the hook on the
DDP model before the training loop, as shown below.

:func:`torch.nn.parallel.DistributedDataParallel.register_comm_hook`
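
For example, here is a minimal sketch of registering the built-in FP16 compression hook
(documented below). It assumes that the process group has already been initialized via
``torch.distributed.init_process_group`` and that ``rank`` is the GPU index of the
current worker::

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

    # Wrap the local model with DDP as usual.
    model = nn.Linear(16, 16).to(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    # Register the hook once, before entering the training loop.
    # ``state=None`` uses the default (global) process group.
    ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)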

What Does a Communication Hook Operate On?
------------------------------------------

A communication hook provides a flexible way to allreduce gradients.
Therefore, it mainly operates on the gradients on each replica before allreduce,
which are bucketized to increase the overlap between communication and computation.
In particular, :class:`torch.distributed.GradBucket` represents a bucket of gradient tensors to be allreduced.

.. autoclass:: torch.distributed.GradBucket

.. autofunction:: torch.distributed.GradBucket.index
.. autofunction:: torch.distributed.GradBucket.buffer
.. autofunction:: torch.distributed.GradBucket.gradients
.. autofunction:: torch.distributed.GradBucket.is_last
.. autofunction:: torch.distributed.GradBucket.set_buffer
.. autofunction:: torch.distributed.GradBucket.parameters
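
For reference, a user-defined hook is a callable with the signature
``hook(state, bucket) -> torch.futures.Future[torch.Tensor]``. Below is a minimal,
purely illustrative sketch that reproduces the vanilla allreduce by operating on the
bucket's flattened buffer (the built-in ``allreduce_hook`` already provides this
behavior)::

    import torch
    import torch.distributed as dist

    def custom_allreduce_hook(
        process_group: dist.ProcessGroup, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        group = process_group if process_group is not None else dist.group.WORLD
        world_size = group.size()

        # Divide first so that the allreduced result is already the average.
        tensor = bucket.buffer().div_(world_size)
        fut = dist.all_reduce(tensor, group=group, async_op=True).get_future()

        # The returned Future must hold a tensor of the same shape as bucket.buffer().
        return fut.then(lambda f: f.value()[0])

    # Registered like any built-in hook:
    # ddp_model.register_comm_hook(state=None, hook=custom_allreduce_hook)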

Default Communication Hooks
---------------------------

Default communication hooks are simple **stateless** hooks, so the input state
in ``register_comm_hook`` is either a process group or ``None``.
The input ``bucket`` is a :class:`torch.distributed.GradBucket` object.

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks
.. autofunction:: allreduce_hook
.. autofunction:: fp16_compress_hook
.. autofunction:: bf16_compress_hook
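
For example, the state can be an explicit process group instead of ``None`` (which
falls back to the default group). The sketch below assumes the ``ddp_model`` from the
earlier registration example::

    import torch.distributed as dist
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

    # Passing a process group as the hook state; ``None`` would use the default group.
    ddp_model.register_comm_hook(state=dist.group.WORLD, hook=default_hooks.allreduce_hook)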

Additionally, a communication hook wrapper is provided that applies the compression of
:meth:`~fp16_compress_hook` or :meth:`~bf16_compress_hook` on top of another communication hook.

.. autofunction:: fp16_compress_wrapper
.. autofunction:: bf16_compress_wrapper
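
For example, here is a sketch of wrapping the PowerSGD hook (described in the next
section) with FP16 compression, again assuming the ``ddp_model`` from the earlier
example::

    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

    state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1)

    # The wrapper casts the bucket to FP16 before handing it to the wrapped hook,
    # and casts the result back to the original dtype afterwards.
    ddp_model.register_comm_hook(
        state, default_hooks.fp16_compress_wrapper(powerSGD.powerSGD_hook)
    )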

PowerSGD Communication Hook
---------------------------

PowerSGD (`Vogels et al., NeurIPS 2019 <https://arxiv.org/abs/1905.13727>`_)
is a gradient compression algorithm, which can provide very high compression
rates and accelerate bandwidth-bound distributed training.
This algorithm needs to maintain both hyperparameters and internal
state. Therefore, the PowerSGD communication hook is a **stateful** hook,
and the user needs to provide a state object defined below.

PowerSGD State
^^^^^^^^^^^^^^^^

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook
.. autoclass:: PowerSGDState

PowerSGD Hooks
^^^^^^^^^^^^^^^^

.. warning ::
    PowerSGD typically requires extra memory of the same size as the model's
    gradients to enable error feedback, which can compensate for biased
    compressed communication and improve accuracy.

.. warning ::
    PowerSGD hooks may conflict with `Apex automatic mixed precision package <https://github.com/NVIDIA/apex>`_.
    Please use PyTorch `native automatic mixed precision package <https://pytorch.org/docs/stable/amp.html>`_
    instead.

.. autofunction:: powerSGD_hook
.. autofunction:: batched_powerSGD_hook
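
For example, here is a sketch of registering the PowerSGD hook with a few illustrative
(not prescriptive) hyperparameter values, assuming the ``ddp_model`` from the earlier
example::

    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

    state = powerSGD.PowerSGDState(
        process_group=None,            # use the default process group
        matrix_approximation_rank=2,   # higher rank: better accuracy, less compression
        start_powerSGD_iter=1000,      # run vanilla allreduce during warm-up iterations
    )
    ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)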

Debugging Communication Hooks
-----------------------------

As the name implies, debugging communication hooks are used **only** for debugging and performance optimization purposes.

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks

.. warning ::
    Debugging communication hooks do not necessarily output the correct results.

.. autofunction:: noop_hook
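
For example, here is a sketch of registering the no-op hook to estimate the
communication overhead of DDP, assuming the ``ddp_model`` from the earlier example::

    from torch.distributed.algorithms.ddp_comm_hooks import debugging_hooks

    # Skips gradient communication entirely; useful for measuring how much time
    # DDP spends on communication, but the resulting gradients are not synchronized.
    ddp_model.register_comm_hook(None, debugging_hooks.noop_hook)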

Checkpointing of Communication Hooks
------------------------------------

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook

A stateful communication hook can be saved as a part of model checkpointing to enable trainer restarts.
To make a hook serializable, ``__setstate__`` and ``__getstate__`` should be defined.

.. warning ::
    ``__getstate__`` should exclude non-serializable attributes from a returned dictionary.

.. warning ::
    ``__setstate__`` should properly initialize non-serializable attributes, excluded from a provided ``state``.

:class:`PowerSGDState` has ``__setstate__`` and ``__getstate__`` implemented and can be used as a reference.

.. class:: PowerSGDState
    :noindex:

    .. automethod:: PowerSGDState.__getstate__
    .. automethod:: PowerSGDState.__setstate__

Here is a simple, end-to-end example of saving and reloading PowerSGD state and hook.

::

    import os
    import tempfile
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim

    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
    from torch.nn.parallel import DistributedDataParallel

    class SimpleModel(nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.fc1 = nn.Linear(24,24)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(24,12)

        def forward(self, x):
            return self.fc2(self.relu(self.fc1(x)))

    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def cleanup():
        dist.destroy_process_group()

    def run_demo(demo_fn, world_size):
        mp.spawn(
            demo_fn,
            args=(world_size,),
            nprocs=world_size,
            join=True)

    def demo_serialization(rank, world_size):
        setup(rank, world_size)

        CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"

        model = SimpleModel().to(rank)
        ddp_model = DistributedDataParallel(model, device_ids=[rank])

        powersgd_hook = powerSGD.powerSGD_hook
        powersgd_state = powerSGD.PowerSGDState(process_group=None)

        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
        ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

        # Save the model weights together with the hook and its state.
        state = {
            'state_dict': ddp_model.state_dict(),
            'comm_hook': powersgd_hook,
            'comm_hook_state': powersgd_state}

        if rank == 0:
            torch.save(state, CHECKPOINT)

        dist.barrier()
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        checkpoint = torch.load(CHECKPOINT, map_location=map_location)

        ddp_model.load_state_dict(checkpoint['state_dict'])
        powersgd_hook = checkpoint['comm_hook']
        powersgd_state = checkpoint['comm_hook_state']

        ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

        if rank == 0:
            os.remove(CHECKPOINT)

        cleanup()

    if __name__ == "__main__":
        n_gpus = torch.cuda.device_count()
        assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
        world_size = n_gpus
        run_demo(demo_serialization, world_size)

Acknowledgements
----------------

Many thanks to PowerSGD paper author **Thijs Vogels** for the code review on
PowerSGD communication hook, as well as the
`comparison experiments <https://observablehq.com/@tvogels/powersgd-benchmark>`_,
which show that the performance of PowerSGD communication hook is on par with
the implementation in the original `paper <https://arxiv.org/abs/1905.13727>`_.