File: partial_tensor.py

import functools
from typing import Callable, Dict, TYPE_CHECKING

import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
from torch.distributed import distributed_c10d
from torch.distributed.nn.functional import (
    reduce_scatter,
)
from torch.distributed._shard.common_op_utils import _register_default_op
from torch.distributed._shard.op_registry_utils import _decorator_func
from torch.utils._pytree import tree_map

if TYPE_CHECKING:
    # Only import ShardedTensor when type checking; exclude it at
    # run-time to avoid a circular dependency.
    from torch.distributed._shard.sharded_tensor import ShardedTensor

# Custom PartialTensor ops
_PARTIAL_TENSOR_OPS: Dict[Callable, Callable] = {}

def _custom_partial_tensor_op(func):
    """
    Decorator for a custom partial tensor op.
    Args:
        func(Callable): Torch function for which we want to provide a PartialTensor
            implementation (ex: torch.nn.functional.linear)
    """
    return functools.partial(
        _decorator_func,
        op=func,
        op_table=_PARTIAL_TENSOR_OPS
    )
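
# A minimal usage sketch of the decorator above (illustrative only, not
# executed): a handler registered this way follows the same signature as the
# ops registered further down in this file (``partial_transpose``,
# ``partial_cat``); the handler name below is hypothetical.
#
#     @_custom_partial_tensor_op(torch.Tensor.clone)
#     def partial_clone(types, args=(), kwargs=None, process_group=None):
#         partial_tensor = args[0]
#         return _PartialTensor(
#             partial_tensor._local_shard.clone(),
#             process_group,
#             partial_tensor._reduce_op,
#         )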

class _PartialTensor(torch.Tensor):
    """
    PartialTensor is an abstraction to represent Tensors that need
    aggregation across multiple devices and multiple processes.

    PartialTensor is initialized in an SPMD-like fashion, where each rank
    initializes its own PartialTensor. The PartialTensor object on each rank
    then stores only the local partial shard, the process group and the
    reduce op used to aggregate the shards into a full tensor.

    PartialTensor doesn't provide any Tensor-like operations itself; it is a
    wrapper around the Tensor representing the local partial shard.

    We assume the size of each local tensor to be exactly the same.

    Users can apply custom distributed sharded computations on top of
    this primitive.

    Args:
        local_partial_shard (Tensor): Partial result stored across ranks.
        process_group (ProcessGroup): The process group to aggregate on.
        reduce_op (distributed_c10d.ReduceOp): Way to aggregate the partial result.
            Default: ``distributed_c10d.ReduceOp.SUM``

    Examples:
        >>> # All tensors below are of torch.int64 type.
        >>> # We have 2 process groups, 2 ranks.
        >>> # xdoctest: +SKIP
        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> tensor = torch.cat([tensor, tensor + 2])
        >>> tensor
        tensor([1, 2, 3, 4]) # Rank 0
        tensor([3, 4, 5, 6]) # Rank 1
        >>> partial_tensor = _PartialTensor(tensor, reduce_op=distributed_c10d.ReduceOp.MAX)
        >>> sharding_dim = 0
        >>> collect_spec = shard_spec.ChunkShardingSpec(
                dim=sharding_dim,
                placements=[
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                ],
            )
        >>> complete_tensor = partial_tensor.reshard(collect_spec)
        >>> complete_tensor
        ShardedTensor(
            ShardedTensorMetadata(
                shards_metadata=[
                    ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0),
                    ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1)],
                size=torch.Size([4])
        )
        >>> complete_tensor.local_tensor()
        tensor([3, 4]) # Rank 0
        tensor([5, 6]) # Rank 1

        >>> # All tensors below are of torch.int64 type.
        >>> # We have 2 process groups, 2 ranks. The default reduce op is ReduceOp.SUM.
        >>> tensor = torch.tensor([1, 2]) + 2 * rank
        >>> tensor = torch.cat([tensor, tensor + 2])
        >>> tensor
        tensor([1, 2, 3, 4]) # Rank 0
        tensor([3, 4, 5, 6]) # Rank 1
        >>> partial_tensor = _PartialTensor(tensor)
        >>> complete_tensor = partial_tensor.reshard(collect_spec)
        >>> complete_tensor
        ShardedTensor(
            ShardedTensorMetadata(
                shards_metadata=[
                    ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0),
                    ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1)],
                size=torch.Size([4])
        )
        >>> complete_tensor.local_tensor()
        tensor([4, 6]) # Rank 0
        tensor([8, 10]) # Rank 1
    """

    _process_group: distributed_c10d.ProcessGroup
    _local_shard: torch.Tensor
    _reduce_op: distributed_c10d.ReduceOp

    __slots__ = ["_process_group", "_local_shard", "_reduce_op"]

    def __new__(cls, local_shard, process_group=None, reduce_op=distributed_c10d.ReduceOp.SUM):
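        # _make_wrapper_subclass builds a Tensor "shell" that mirrors the
        # metadata of local_shard (size, dtype, layout, ...) but holds no
        # storage of its own; the data lives in the wrapped _local_shard.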
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            local_shard.size(),
            dtype=local_shard.dtype,
            layout=local_shard.layout,
            pin_memory=local_shard.is_pinned(),
            requires_grad=local_shard.requires_grad)      # type: ignore[arg-type]
        r._process_group = (     # type: ignore[attr-defined]
            process_group
            if process_group is not None
            else distributed_c10d._get_default_group()
        )
        r._reduce_op = reduce_op
        r._local_shard = local_shard
        # Validate here: tensor subclasses built via __new__ never go through a
        # dataclass-style __post_init__, so this is the only reliable hook.
        if not isinstance(reduce_op, distributed_c10d.ReduceOp):
            raise ValueError(
                "reduce_op needs to be a member of distributed_c10d.ReduceOp."
            )
        return r

    def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> "ShardedTensor":
        """
        The reshard happens in two steps logically:

        1. Aggregate all the shards of the partial tensor.
        2. Shard this tensor according to the provided spec.

        In reality, for the sake of performance, we consolidate all the partial
        tensors across ranks and convert to a sharded tensor in one step.

        Args:
            resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
                The specification describing how we reshard the aggregated local result.

        Returns:
            A :class:`ShardedTensor` filled with the local aggregated result.
        """
        from torch.distributed._shard.sharded_tensor.api import ShardedTensor

        if not isinstance(resharding_spec, shard_spec.ChunkShardingSpec):
            raise NotImplementedError("Only ChunkShardingSpec supported for reshard.")
        if self._local_shard.is_complex():
            raise NotImplementedError("Only real partial tensor supported for reshard.")
        sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
        chunk_mode_res = self._local_shard.size(sharding_dim) % self._process_group.size()
        local_shard = self._local_shard
        # Add padding when the size is not divisible by the world size.
        if chunk_mode_res != 0:
            padding = [0] * (local_shard.dim() * 2)
            # F.pad orders pads from the last dim forward; pad the end of
            # sharding_dim so its size becomes divisible by the world size.
            padding[2 * (local_shard.dim() - 1 - sharding_dim) + 1] = (
                self._process_group.size() - chunk_mode_res
            )
            local_shard = torch.nn.functional.pad(
                local_shard,
                tuple(padding),
                "constant",
                0,
            )
        current_rank = dist.get_rank(self._process_group)  # type: ignore[attr-defined]
        rank_idx = None
        rearrange_local_shards = False
        indices = [0] * self._process_group.size()
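        # Walk the placements to record, for each rank, which chunk of the
        # sharding dim it owns (indices[rank] = chunk index), remember this
        # rank's own chunk index, and flag whether the chunks need re-ordering
        # before the reduce_scatter below.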
        for idx, placement in enumerate(resharding_spec.placements):  # type: ignore[attr-defined]
            if placement.rank() == current_rank:  # type: ignore[index, union-attr]
                rank_idx = idx  # type: ignore[attr-defined]
            if placement.rank() != idx:  # type: ignore[index, union-attr]
                rearrange_local_shards = True
            indices[placement.rank()] = idx  # type: ignore[index, union-attr]

        local_shards = local_shard.chunk(self._process_group.size(), dim=sharding_dim)
        if rearrange_local_shards:
            # Re-order the chunks so that position i holds the chunk owned by rank i.
            local_shards = [local_shards[idx] for idx in indices]  # type: ignore[call-overload]
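        # reduce_scatter reduces chunk i elementwise across all ranks with
        # self._reduce_op and delivers the reduced chunk i to rank i, so every
        # rank ends up holding only the aggregated chunk it owns.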
        local_result = reduce_scatter(
            torch.empty_like(local_shards[0]),
            list(local_shards),
            op=self._reduce_op,
            group=self._process_group,
        )

        sharded_tensor_size = self._local_shard.size()
        # Remove padding when the size is not divisible by the world size.
        if chunk_mode_res != 0:
            uneven_local_shards = self._local_shard.chunk(
                self._process_group.size(), dim=sharding_dim
            )
            expected_size = uneven_local_shards[rank_idx].size()  # type: ignore[index]
            if local_result.size() != expected_size:
                local_result = local_result.narrow(
                    sharding_dim,
                    0,
                    expected_size[sharding_dim],
                )
        return ShardedTensor._init_from_local_tensor(
            local_result,
            resharding_spec,
            sharded_tensor_size,
            process_group=self._process_group,
        )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Find process_group
        process_group = None

        def find_process_group(e):
            nonlocal process_group
            if process_group is None and isinstance(e, _PartialTensor):
                process_group = e._process_group

        tree_map(find_process_group, args)
        tree_map(find_process_group, kwargs)

        if func in _PARTIAL_TENSOR_OPS:
            return _PARTIAL_TENSOR_OPS[func](types, args, kwargs, process_group)

        # Need to disable all dispatch to print args and kwargs appropriately.
        guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
        try:
            with torch._C.DisableTorchFunction():
                raise RuntimeError(
                    f"torch function '{func.__name__}', with args: {args} and "
                    f"kwargs: {kwargs} not supported for PartialTensor!")
        finally:
            del guard

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise RuntimeError(
            f"A {cls.__name__} object is being used from c++ "
            f"while calling {func.__module__}.{func.__name__} "
            "but the there is no custom __torch_dispatch__ implementation for it."
        )

    def __repr__(self):
        return f"PartialTensor({super(_PartialTensor, self).__repr__()})"

def _transpose_impl(types, args=(), kwargs=None, process_group=None):
    partial_tensor = args[0]
    input = partial_tensor._local_shard
    dim0 = args[1]
    dim1 = args[2]
    return _PartialTensor(
        torch.transpose(input, dim0, dim1),
        process_group,
        partial_tensor._reduce_op
    )

@_custom_partial_tensor_op(torch.Tensor.transpose)
def partial_transpose(types, args=(), kwargs=None, process_group=None):
    return _transpose_impl(types, args, kwargs, process_group)

@_custom_partial_tensor_op(torch.transpose)
def partial_torch_transpose(types, args=(), kwargs=None, process_group=None):
    return _transpose_impl(types, args, kwargs, process_group)
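
# Illustrative only: with the two registrations above, transposing a
# _PartialTensor transposes the local shard and keeps the process group and
# reduce op. Assuming ``pt`` is a _PartialTensor wrapping a 2-D local shard
# and ``spec`` is a ChunkShardingSpec (both hypothetical):
#
#     transposed = torch.transpose(pt, 0, 1)  # still a _PartialTensor
#     sharded = transposed.reshard(spec)      # aggregate, then shard per spec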

@_custom_partial_tensor_op(torch.cat)
def partial_cat(types, args=(), kwargs=None, process_group=None):
    input_list = args[0]
    if len(input_list) == 0:
        raise RuntimeError('Empty list of tensors to torch.cat!')

    local_shards = []
    for idx, input in enumerate(input_list):
        if not isinstance(input, _PartialTensor):
            raise RuntimeError('All inputs need to be an instance of _PartialTensor')
        if idx == 0:
            reduce_op = input._reduce_op
        elif reduce_op != input._reduce_op:
            raise RuntimeError(
                'All _PartialTensor reduce_ops need to be the same, found: '
                f'{reduce_op} and {input._reduce_op}'
            )

        local_shards.append(input._local_shard)

    if kwargs is None:
        dim = 0
    else:
        if 'out' in kwargs:
            raise RuntimeError('"out" kwarg is not supported!')
        dim = kwargs['dim'] if 'dim' in kwargs else 0

    return _PartialTensor(torch.cat(local_shards, dim), process_group, reduce_op)

# Tensor properties access
_register_default_op(torch.Tensor.requires_grad.__get__, _custom_partial_tensor_op)  # type: ignore[attr-defined]
_register_default_op(torch.Tensor.shape.__get__, _custom_partial_tensor_op)  # type: ignore[attr-defined]
_register_default_op(torch.Tensor.dtype.__get__, _custom_partial_tensor_op)  # type: ignore[attr-defined]
_register_default_op(torch.Tensor.layout.__get__, _custom_partial_tensor_op)  # type: ignore[attr-defined]
_register_default_op(torch.Tensor.size, _custom_partial_tensor_op)
_register_default_op(torch.Tensor.dim, _custom_partial_tensor_op)
_register_default_op(torch.Tensor.ndim.__get__, _custom_partial_tensor_op)  # type: ignore[attr-defined]
_register_default_op(torch.Tensor.is_contiguous, _custom_partial_tensor_op)
_register_default_op(torch.Tensor.contiguous, _custom_partial_tensor_op)
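
# A hedged end-to-end sketch tying the pieces together (illustrative only; the
# names ``rank``, ``pg`` and the CUDA placements are assumptions about an
# already-initialized process group of world size 2):
#
#     local = torch.arange(4, dtype=torch.float32) + 4 * rank
#     pt = _PartialTensor(local, process_group=pg)   # default ReduceOp.SUM
#     cat_pt = torch.cat([pt, pt])                   # still a _PartialTensor
#     spec = shard_spec.ChunkShardingSpec(
#         dim=0,
#         placements=["rank:0/cuda:0", "rank:1/cuda:1"],
#     )
#     sharded = cat_pt.reshard(spec)                 # ShardedTensor
#     print(sharded.local_tensor())                  # this rank's shard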