File: hierarchical_model_averager.py

# Copyright 2022 Cruise LLC
import logging
import warnings
from collections import OrderedDict
from typing import Union, Iterable, Dict

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.utils as utils

logger = logging.getLogger(__name__)


class HierarchicalModelAverager(averagers.ModelAverager):
    r"""
    Runs hierarchical model averaging (`hierarchical SGD <https://arxiv.org/pdf/2010.12998.pdf>`_).
    Process groups of different sizes are organized in a hierarchy, and they average parameters
    concurrently using different periods after the warm-up stage.
    This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`
    that supports `post-local SGD <https://arxiv.org/abs/1808.07217>`_, which essentially only supports
    a two-level hierarchy: the intra-machine level and the global level, where the intra-machine
    level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`.
    Similarly, the process groups within this class do not have such an intra-machine process
    subgroup; that subgroup should instead be embedded by the post-local SGD communication hook.

    Args:
        period_group_size_dict: An ordered dict mapping model averaging periods to
                                process group sizes, used for initializing process groups of
                                different sizes in a hierarchy to average parameters concurrently.
                                In particular, at each iteration there will be at most one
                                process group that runs averaging -- the group whose period is
                                the largest one that divides the current step.
                                For example, if the dict has three keys: 2, 4, and 8,
                                then three process groups will be created to average
                                parameters every 2, 4, and 8 iterations, respectively.
                                At the 4th iteration, only the second process group runs
                                averaging, because each group at the first level is a subset
                                of a group at the second level, so also running the first
                                level would be redundant.
                                On the other hand, the third process group is only triggered
                                every 8 iterations, so it is not triggered at the 4th iteration.
                                A minimal sketch of this selection rule is shown after this
                                argument list.
        warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped.
        process_group (ProcessGroup, optional): The overall process group containing all the processes that run model averaging.
                                                If ``None``, the default process group, which is created
                                                by :func:`torch.distributed.init_process_group`, will be used.
                                                (default: ``None``)
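
    For illustration only, the period-selection rule can be sketched in plain Python (a minimal
    sketch using hypothetical names ``periods`` and ``step``, not part of this class's API)::

        >>> periods = [2, 4, 8]  # ascending keys of ``period_group_size_dict``
        >>> step = 4
        >>> max((p for p in periods if step % p == 0), default=None)
        4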

    Example::
        >>> # xdoctest: +SKIP('undefined rank')
        >>> from collections import OrderedDict
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>     PostLocalSGDState,
        >>>     post_localSGD_hook,
        >>> )
        >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
        >>> import torch.nn as nn
        >>>
        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
        >>> torch.cuda.set_device(rank)
        >>> module = nn.Linear(1, 1, bias=False).to(rank)
        >>> model = nn.parallel.DistributedDataParallel(
        >>>    module, device_ids=[rank], output_device=rank
        >>> )
        >>> # Register a post-localSGD communication hook.
        >>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4.
        >>> subgroup, _ = dist.new_subgroups()
        >>> state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Average parameters among each group of 8 processes every 4 iterations, and among all
        >>> # the 16 processes every 16 iterations.
        >>> averager = hierarchicalSGD.HierarchicalModelAverager(
        >>>     period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100)
        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
        >>> # After 100 steps, run model averaging at two levels.
        >>> # ``optimizer``, ``loss_fn``, ``inputs``, and ``labels`` are assumed to be defined elsewhere.
        >>> for step in range(0, 200):
        >>>    optimizer.zero_grad()
        >>>    output = model(inputs)
        >>>    loss = loss_fn(output, labels)
        >>>    loss.backward()
        >>>    optimizer.step()
        >>>    # Average parameters after ``optimizer.step()``.
        >>>    # Thus, the inter-node communication only occurs periodically after ``warmup_steps``.
        >>>    averager.average_parameters(model.parameters())

    .. warning ::
        The last group size in the dict must be the size of the provided ``process_group``,
        which indicates model averaging at the highest level of the hierarchy.
        If ``process_group`` is not provided, then the last group size should be equal to the world size.

    .. warning ::
        `HierarchicalModelAverager` is experimental and subject to change.
    """

    def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
        super().__init__(process_group)
        if not period_group_size_dict:
            raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
        self._periods = list(period_group_size_dict.keys())
        if self._periods[0] <= 0:
            raise ValueError("The minimum period in arg ``period_group_size_dict`` must be a positive value.")
        elif self._periods[-1] == 1:
            warnings.warn(
                "When the maximum period in arg ``period_group_size_dict`` is 1, "
                "no need to use model averaging because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        overall_group_size = dist.get_world_size(group=self.process_group)
        if list(period_group_size_dict.values())[-1] != overall_group_size:
            raise ValueError(
                f"The last value in arg ``period_process_group_dict`` {list(period_group_size_dict.values())[-1]} "
                f"must be equal to the size of arg ``process_group`` {overall_group_size}."
            )

        self.period_process_group_dict = OrderedDict()
        logger.info("Model averaging hierarchy:")
        for period, group_size in period_group_size_dict.items():
            logger.info(
                f"\tEach group of {group_size} processes averages parameters every {period} iterations, "
                "if no higher-level averaging runs at that iteration.")
            if group_size != overall_group_size:
                self.period_process_group_dict[period], _ = dist.new_subgroups(
                    group_size=group_size, group=self.process_group)
            else:
                self.period_process_group_dict[period] = self.process_group

        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
        self.warmup_steps = warmup_steps

    def _find_process_group(self):
        """
        Returns a process group from ``period_process_group_dict`` whose period divides ``step``.
        If ``step`` is divisible by multiple periods in the keys of ``period_process_group_dict``,
        the returned process group is the one corresponding to the largest such period,
        since that is the group used for averaging parameters at this ``step``.
        Returns ``None`` if no period divides ``step``.
        """
        for period in reversed(self._periods):
            if self.step % period == 0:
                return self.period_process_group_dict[period]
        return None

    def average_parameters(self, params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]):
        """
        Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``
        and is divisible by a period in the keys of ``period_process_group_dict``,
        where ``step`` is increased by 1 at each iteration in the training loop.
        If ``step`` is divisible by multiple periods in the keys of ``period_process_group_dict``,
        only the largest period is used, and its corresponding process group is used for averaging parameters.

        Args:
            params: The parameters of a model or parameter groups of an optimizer.
        """
        if self.step >= self.warmup_steps:
            group = self._find_process_group()
            if group is not None:
                utils.average_parameters_or_parameter_groups(params, group)
        self.step += 1