File: xla.py

from typing import Any, Callable, cast, List, Mapping, Optional, Tuple

import torch

from ignite.distributed.comp_models.base import ComputationModel

# torch_xla is an optional dependency: only define the XLA computation model below
# when it can actually be imported.
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

    has_xla_support = True
except ImportError:
    has_xla_support = False


if has_xla_support:
    XLA_TPU = "xla-tpu"

    class _XlaDistModel(ComputationModel):
        """Private class for PyTorch XLA basic distributed computation model.
        It handles single- and multi-device computation models.

        Supported XLA devices:

        - CPU
        - TPU

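        Examples:
            A minimal usage sketch, assuming an XLA runtime (e.g. a TPU device) is
            available in the current process:

            .. code-block:: python

                model = _XlaDistModel.create_from_backend("xla-tpu")
                print(model.get_rank(), model.get_world_size(), model.device())
                model.finalize()
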
        """

        name = "xla-dist"

        available_backends = (XLA_TPU,)

        @staticmethod
        def create_from_context() -> Optional["_XlaDistModel"]:
            return _XlaDistModel()

        @staticmethod
        def create_from_backend(backend: str = XLA_TPU, **kwargs: Any) -> "_XlaDistModel":
            if backend not in _XlaDistModel.available_backends:
                raise ValueError(f"Backend should be one of '{_XlaDistModel.available_backends}'")

            return _XlaDistModel(backend=backend, **kwargs)

        def __init__(self, backend: Optional[str] = None, **kwargs: Any):
            """This is a private method. Please, use `create_from_backend` or `create_from_context`"""
            super(_XlaDistModel, self).__init__()
            if backend is not None:
                self._create_from_backend(backend, **kwargs)
            else:
                self._init_from_context()

        def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
            xm.rendezvous("init")

            self._backend: str = backend
            self._setup_attrs()

        def _init_from_context(self) -> None:
            self._backend = XLA_TPU
            self._setup_attrs()

        def _compute_nproc_per_node(self) -> int:
            # The max of (local_rank + 1) over all processes equals the number of
            # processes per node (assuming every node runs the same number of processes).
            tensor = torch.tensor([self.get_local_rank() + 1.0], dtype=torch.float).to(self.device())
            xm.all_reduce("max", [tensor])
            return int(tensor.item())

        def get_local_rank(self) -> int:
            return xm.get_local_ordinal()

        def get_rank(self) -> int:
            return xm.get_ordinal()

        def get_world_size(self) -> int:
            return xm.xrt_world_size()

        def get_nproc_per_node(self) -> int:
            return cast(int, self._nproc_per_node)

        def get_nnodes(self) -> int:
            return cast(int, self._nnodes)

        def get_node_rank(self) -> int:
            return cast(int, self._node)

        def device(self) -> torch.device:
            dev = torch_xla._XLAC._xla_get_default_device()
            return torch.device(dev)

        def backend(self) -> str:
            return self._backend

        def finalize(self) -> None:
            pass

        @staticmethod
        def _dist_worker_task_fn(
            local_rank: int, backend: str, fn: Callable, args: Tuple, kwargs_dict: Mapping
        ) -> None:
            from ignite.distributed.utils import _set_model, finalize

            model = _XlaDistModel.create_from_backend(backend)
            _set_model(model)
            fn(local_rank, *args, **kwargs_dict)
            finalize()

        @staticmethod
        def spawn(
            fn: Callable,
            args: Tuple,
            kwargs_dict: Optional[Mapping] = None,
            nproc_per_node: int = 1,
            nnodes: int = 1,
            node_rank: int = 0,
            backend: str = XLA_TPU,
            **kwargs: Any,
        ) -> None:
            if "start_method" not in kwargs:
                kwargs["start_method"] = "fork"

            xmp.spawn(
                _XlaDistModel._dist_worker_task_fn,
                args=(backend, fn, args, kwargs_dict),
                nprocs=nproc_per_node,
                **kwargs,
            )

        _collective_op_dtype = torch.float32
        _reduce_op_map = {
            "SUM": "sum",
            "PRODUCT": "mul",
            "MIN": "min",
            "MAX": "max",
            "AND": "and",
            "OR": "or",
        }

        def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[Any] = None) -> torch.Tensor:
            if group is not None and not self._check_group_type(group):
                raise ValueError("Argument group should be list of int")
            op = self._reduce_op_map[op]
            xm.all_reduce(op, [tensor], groups=group)
            return tensor

        def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
            # from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb

            if group is not None and (not isinstance(group, list) or not all(isinstance(item, int) for item in group)):
                raise ValueError("Argument group should be list of int")

            group_size = self.get_world_size()
            # Each rank writes its tensor into its own slot of a zero-filled buffer; a
            # sum all-reduce across ranks then fills every slot, emulating an all-gather.
            output = torch.zeros((group_size,) + tensor.shape, dtype=tensor.dtype, device=tensor.device)
            output[self.get_rank() % group_size] = tensor
            xm.all_reduce("sum", [output], groups=group)
            return output.reshape(-1, *output.shape[2:])

        def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
            raise NotImplementedError("all_gather on object is not implemented for xla")

        def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
            return [ranks]

        def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
            # from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb
            # Non-source ranks zero their copy so that a sum all-reduce leaves every
            # rank holding the source rank's value.
            if src != self.get_rank():
                tensor.fill_(0.0)
            xm.all_reduce("sum", [tensor])
            return tensor

        def barrier(self) -> None:
            xm.rendezvous("barrier")

        def _check_group_type(self, group: Optional[Any]) -> bool:
            return isinstance(group, list) and all(isinstance(item, int) for item in group)
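

# Illustrative sketch (not executed): this private model is normally driven through
# the public ``ignite.distributed`` API rather than used directly. Assuming an XLA
# runtime with 8 devices is available, spawning workers might look like:
#
#     import ignite.distributed as idist
#
#     def training(local_rank):
#         print(idist.get_rank(), idist.device())
#
#     idist.spawn("xla-tpu", training, args=(), nproc_per_node=8)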