File: horovod.py

package info (click to toggle)
pytorch-ignite 0.5.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,712 kB
  • sloc: python: 46,874; sh: 376; makefile: 27
file content (210 lines) | stat: -rw-r--r-- 8,143 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import warnings
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple

import torch

from ignite.distributed.comp_models.base import ComputationModel

try:
    import horovod.torch as hvd

    try:
        # old API
        from horovod.run.runner import run as hvd_mp_spawn
    except ImportError:
        # new API: https://github.com/horovod/horovod/pull/2099
        from horovod import run as hvd_mp_spawn

    has_hvd_support = True
except ImportError:
    has_hvd_support = False


if has_hvd_support:
    HOROVOD = "horovod"

    class _HorovodDistModel(ComputationModel):
        """Private class for `Horovod <https://horovod.readthedocs.io/en/stable/>`_ distributed computation model.

        Instances are meant to be obtained through :meth:`create_from_backend`
        (which initializes Horovod) or :meth:`create_from_context` (which wraps
        an already initialized Horovod) rather than by calling the constructor.
        """

        name = "horovod-dist"

        available_backends = (HOROVOD,)

        @staticmethod
        def _get_hvd_rank() -> int:
            """Return ``hvd.rank()``, or ``-1`` if that call raises ``ValueError``.

            Callers in this class interpret ``-1`` as "Horovod is not initialized".
            """
            try:
                return hvd.rank()
            except ValueError:
                return -1

        @staticmethod
        def create_from_context() -> Optional["_HorovodDistModel"]:
            """Build a model on top of an already initialized Horovod.

            Returns:
                A new model, or ``None`` when Horovod is not initialized.
            """
            # hvd must be initialized
            if _HorovodDistModel._get_hvd_rank() < 0:
                return None
            return _HorovodDistModel()

        @staticmethod
        def create_from_backend(backend: str = HOROVOD, **kwargs: Any) -> "_HorovodDistModel":
            """Initialize Horovod and return a model bound to it.

            Args:
                backend: must be one of ``available_backends``.
                kwargs: forwarded to the constructor; only ``comm`` is read there.

            Raises:
                ValueError: if ``backend`` is not an available backend.
                RuntimeError: if Horovod is already initialized.
            """
            if backend not in _HorovodDistModel.available_backends:
                raise ValueError(f"Backend should be one of '{_HorovodDistModel.available_backends}'")

            # hvd must be not initialized
            if _HorovodDistModel._get_hvd_rank() > -1:
                raise RuntimeError("Can not re-initialize Horovod if it is already initialized")
            return _HorovodDistModel(backend, **kwargs)

        def __init__(self, backend: Optional[str] = None, **kwargs: Any) -> None:
            """This is a private method. Please, use `create_from_backend` or `create_from_context`"""
            super().__init__()
            if backend is None:
                self._init_from_context()
            else:
                self._create_from_backend(backend, **kwargs)

        def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
            """Initialize Horovod, set up attributes and select the CUDA device.

            Only the optional ``comm`` keyword is read from ``kwargs`` and passed
            to ``hvd.init``; any other keyword is ignored here.
            """
            self._backend: str = backend
            hvd.init(comm=kwargs.get("comm"))
            self._setup_attrs()
            if torch.cuda.is_available():
                # Bind this process to the GPU matching its local rank.
                torch.cuda.set_device(self.get_local_rank())

        def _init_from_context(self) -> None:
            """Set up attributes for an already initialized Horovod."""
            self._backend = HOROVOD
            self._setup_attrs()

        def _compute_nproc_per_node(self) -> int:
            """Return ``hvd.local_size()``."""
            return hvd.local_size()

        def get_local_rank(self) -> int:
            """Return ``hvd.local_rank()``."""
            return hvd.local_rank()

        def get_rank(self) -> int:
            """Return ``hvd.rank()``."""
            return hvd.rank()

        def get_world_size(self) -> int:
            """Return ``hvd.size()``."""
            return hvd.size()

        def get_nproc_per_node(self) -> int:
            """Return the stored ``_nproc_per_node`` attribute.

            The attribute is not assigned in this class; presumably it is
            populated by the base class ``_setup_attrs`` -- verify there.
            """
            return cast(int, self._nproc_per_node)

        def get_nnodes(self) -> int:
            """Return the stored ``_nnodes`` attribute (not assigned in this class)."""
            return cast(int, self._nnodes)

        def get_node_rank(self) -> int:
            """Return the stored ``_node`` attribute (not assigned in this class)."""
            return cast(int, self._node)

        def device(self) -> torch.device:
            """Return the current CUDA device if CUDA is available, otherwise the CPU device.

            A warning is emitted when the current CUDA device index is lower
            than the local rank.
            """
            if not torch.cuda.is_available():
                return torch.device("cpu")

            index = torch.cuda.current_device()
            if index < self.get_local_rank():
                warnings.warn(
                    "Current device index is less than current local rank. "
                    "Please, make sure to call torch.cuda.set_device(local_rank)."
                )
            return torch.device(f"cuda:{index}")

        def backend(self) -> str:
            """Return the backend name stored at construction time."""
            return self._backend

        def finalize(self) -> None:
            """Shut Horovod down via ``hvd.shutdown()``."""
            hvd.shutdown()

        @staticmethod
        def _dist_worker_task_fn(backend: str, fn: Callable, args: Tuple, kwargs_dict: Optional[Mapping]) -> None:
            """Worker entry point: create the model, run ``fn``, then finalize.

            ``fn`` is called as ``fn(local_rank, *args, **kwargs_dict)``. A
            ``None`` ``kwargs_dict`` is treated as "no keyword arguments".
            """
            # Imported lazily inside the worker, as in the original code.
            from ignite.distributed.utils import _set_model, finalize

            model = _HorovodDistModel.create_from_backend(backend)
            _set_model(model)
            # `**None` would raise TypeError, so fall back to an empty mapping.
            fn(model.get_local_rank(), *args, **(kwargs_dict or {}))
            finalize()

        @staticmethod
        def spawn(
            fn: Callable,
            args: Tuple,
            kwargs_dict: Optional[Mapping] = None,
            nproc_per_node: int = 1,
            hosts: Optional[str] = None,
            backend: str = HOROVOD,
            **kwargs: Any,
        ) -> None:
            """Launch ``fn`` in ``nproc_per_node`` processes through Horovod's ``run``.

            Args:
                fn: callable invoked in every worker as
                    ``fn(local_rank, *args, **kwargs_dict)``.
                args: positional arguments for ``fn`` (after the local rank).
                kwargs_dict: keyword arguments for ``fn``; ``None`` means none.
                nproc_per_node: passed to the launcher as ``num_proc``.
                hosts: passed to the launcher as ``hosts``.
                backend: accepted for signature parity; it is not forwarded,
                    workers are always created with the ``HOROVOD`` backend.
                kwargs: extra keyword arguments forwarded to the launcher.
                    ``nnodes`` and ``node_rank`` are stripped first.

            Raises:
                RuntimeError: if ``nnodes > 1`` or ``node_rank > 0`` is given.
            """
            # 'nnodes' / 'node_rank' are unexpected keyword arguments for
            # horovod.run, so they are always removed before forwarding.
            nnodes = kwargs.pop("nnodes", 1)
            node_rank = kwargs.pop("node_rank", 0)
            if nnodes > 1 or node_rank > 0:
                raise RuntimeError(
                    "For multi-node configuration, please set 'hosts' argument instead according to horovod.run API."
                )

            # The worker unpacks this with `**`, which fails on None.
            if kwargs_dict is None:
                kwargs_dict = {}

            hvd_mp_spawn(
                _HorovodDistModel._dist_worker_task_fn,
                args=(HOROVOD, fn, args, kwargs_dict),
                num_proc=nproc_per_node,
                hosts=hosts,
                **kwargs,
            )

        # Reductions delegated directly to hvd.allreduce.
        _reduce_op_map = {
            "SUM": hvd.mpi_ops.Sum,
            "AVERAGE": hvd.mpi_ops.Average,
            "ADASUM": hvd.mpi_ops.Adasum,
        }

        # Reductions computed locally with torch after an all-gather.
        _manual_reduce_op_map = {"MIN": torch.min, "MAX": torch.max, "PRODUCT": torch.prod}

        def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[Any] = None) -> torch.Tensor:
            """All-reduce ``tensor`` with the reduction named ``op``.

            Raises:
                NotImplementedError: if ``group`` is given.
                ValueError: if ``op`` is in neither reduction map.
            """
            if group is not None:
                raise NotImplementedError("all_reduce with group for horovod is not implemented")

            manual_fn = self._manual_reduce_op_map.get(op)
            if manual_fn is not None:
                return self._do_manual_all_reduce(tensor, manual_fn)

            if op not in self._reduce_op_map:
                raise ValueError(f"Unsupported reduction operation: '{op}'")
            hvd_op = self._reduce_op_map[op]
            return hvd.allreduce(tensor, op=hvd_op)

        def _do_manual_all_reduce(self, tensor: torch.Tensor, op: Any) -> torch.Tensor:
            """Gather ``tensor`` from all processes and reduce it with ``op`` over dim 0."""
            # We have to unsqueeze, otherwise tensors will be gathered into a single tensor
            # without splitting (e.g. [1, 1, 1, 3, 3, 3] instead of [[1, 1, 1], [3, 3, 3]])
            # and the reduction op won't work as expected.
            gathered = self._do_all_gather(tensor.unsqueeze(0))
            reduced = op(gathered, dim=0)
            if isinstance(reduced, torch.Tensor):
                return reduced
            # torch.min / torch.max with `dim` return a (values, indices) pair; keep the values.
            return reduced[0]

        def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
            """Gather ``tensor`` from all processes with ``hvd.allgather``.

            A 0-dimensional tensor is first expanded to one dimension.

            Raises:
                NotImplementedError: if ``group`` is given.
            """
            if group is not None:
                raise NotImplementedError("all_gather with group for horovod is not implemented")
            if tensor.ndimension() == 0:
                tensor = tensor.unsqueeze(0)
            return hvd.allgather(tensor)

        def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
            """Gather an arbitrary object from all processes with ``hvd.allgather_object``.

            Raises:
                NotImplementedError: if ``group`` is given.
            """
            if group is not None:
                raise NotImplementedError("all_gather with group for horovod is not implemented")

            return hvd.allgather_object(tensor)

        def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
            """Return ``hvd.ProcessSet(ranks)``; ``kwargs`` are ignored.

            NOTE(review): the process set is only constructed here -- confirm
            whether it must also be registered with Horovod before use.
            """
            return hvd.ProcessSet(ranks)

        def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
            """Broadcast ``tensor`` from rank ``src`` using ``hvd.broadcast``."""
            return hvd.broadcast(tensor, root_rank=src)

        def barrier(self) -> None:
            """Synchronize processes by all-reducing a scalar CPU tensor named "barrier"."""
            # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
            hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")