File: utils.py

import collections.abc
import copy
from typing import Optional, List, Sequence

import torch
from torch.distributed import distributed_c10d
from torch.distributed import rpc
from torch.distributed._shard.sharding_spec._internals import (
    check_tensor,
    validate_non_overlapping_shards_metadata,
)

from torch.distributed._shard.metadata import ShardMetadata
from .metadata import TensorProperties, ShardedTensorMetadata
from .shard import Shard

def _parse_and_validate_remote_device(pg, remote_device):
    """
    Validate ``remote_device`` against the process group ``pg`` (and against the
    RPC agent when a worker name is given) and return a (rank_or_worker_id, device)
    pair describing the placement.
    """
    if remote_device is None:
        raise ValueError("remote device is None")

    worker_name = remote_device.worker_name()
    rank = remote_device.rank()
    device = remote_device.device()

    # Validate rank, skip validation if rank is not part of process group.
    if not distributed_c10d._rank_not_in_group(pg):
        if rank is not None and (rank < 0 or rank >= distributed_c10d.get_world_size(pg)):
            raise ValueError(f'Invalid rank: {rank}')

    if worker_name is not None:
        if not rpc._is_current_rpc_agent_set():
            raise RuntimeError(f'RPC framework needs to be initialized for using worker names: {worker_name}')

        workers = rpc._get_current_rpc_agent().get_worker_infos()
        for worker in workers:
            if worker.name == worker_name:
                return worker.id, device

        raise ValueError(f'Invalid worker name: {worker_name}')

    return rank, device
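
# Illustrative sketch (not part of the original module): with an initialized
# default process group, a placement such as "rank:0/cuda:0" resolves to the
# pair (0, torch.device("cuda:0")). The use of the default group below is an
# assumption for the example; any valid ProcessGroup works.
#
#   placement = torch.distributed._remote_device("rank:0/cuda:0")
#   rank, device = _parse_and_validate_remote_device(
#       distributed_c10d._get_default_group(), placement)
#   # rank == 0, device == torch.device("cuda:0")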

def _validate_output_tensor_for_gather(
    my_rank: int,
    dst_rank: int,
    size: torch.Size,
    dst_tensor: Optional[torch.Tensor],
) -> None:
    if dst_rank == my_rank:
        if dst_tensor is None:
            raise ValueError(
                f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}"
            )
        if tuple(size) != tuple(dst_tensor.size()):
            raise ValueError(
                f"Argument ``dst_tensor`` has size {tuple(dst_tensor.size())}, "
                f"but should be {tuple(size)}"
            )
    elif dst_tensor is not None:
        raise ValueError(
            "Argument ``dst_tensor`` must NOT be specified "
            "on non-destination ranks."
        )
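
# Illustrative sketch (hypothetical values, not part of the original module):
# on the destination rank the pre-allocated output tensor must match the
# global size exactly; on every other rank it must be left as None.
#
#   out = torch.empty(4, 4)
#   _validate_output_tensor_for_gather(my_rank=0, dst_rank=0,
#                                      size=torch.Size([4, 4]), dst_tensor=out)   # ok
#   _validate_output_tensor_for_gather(my_rank=1, dst_rank=0,
#                                      size=torch.Size([4, 4]), dst_tensor=None)  # ok
#   _validate_output_tensor_for_gather(my_rank=1, dst_rank=0,
#                                      size=torch.Size([4, 4]), dst_tensor=out)   # raises ValueError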

def _flatten_tensor_size(size) -> torch.Size:
    """
    Checks if tensor size is valid, then flatten/return a torch.Size object.
    """
    if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
        dims = list(*size)
    else:
        dims = list(size)

    for dim in dims:
        if not isinstance(dim, int):
            raise TypeError(f'size has to be a sequence of ints, found: {dims}')

    return torch.Size(dims)
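
# Example (a sketch): both calling conventions below produce the same result,
# since a single nested sequence is unpacked before validation.
#
#   _flatten_tensor_size([10, 20, 30])    -> torch.Size([10, 20, 30])
#   _flatten_tensor_size([[10, 20, 30]])  -> torch.Size([10, 20, 30])
#   _flatten_tensor_size([10, 20.5])      -> raises TypeError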

def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True):
    if is_local:
        assert isinstance(ranks, int)
        if expected != actual:
            raise ValueError(f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! "
                             f"Found one local shard tensor {prop_name}={expected}, "
                             f"the other local shard tensor {prop_name}={actual}.")
    else:
        # compare failure check across ranks, ranks list should have two rank
        assert len(ranks) == 2
        if expected != actual:
            raise ValueError(f"ShardedTensor {prop_name} property does not match from different ranks! "
                             f"Found {prop_name}={expected} on rank:{ranks[0]}, "
                             f"and {prop_name}={actual} on rank:{ranks[1]}.")


def build_metadata_from_local_shards(
    local_shards: List[Shard],
    global_size: torch.Size,
    current_rank: int,
    pg: distributed_c10d.ProcessGroup
) -> ShardedTensorMetadata:
    """
    Validate the local shards owned by ``current_rank`` and build a "local"
    ShardedTensorMetadata describing them; callers gather these per-rank
    metadatas and merge them with ``build_global_metadata``.
    """
    assert len(local_shards) > 0, "must have local shards!"
    local_shard_metadatas: List[ShardMetadata] = []

    first_shard_dtype = local_shards[0].tensor.dtype
    first_shard_layout = local_shards[0].tensor.layout
    first_shard_requires_grad = local_shards[0].tensor.requires_grad
    first_shard_is_pinned = local_shards[0].tensor.is_pinned()

    # 1). Validate local tensors and associated metadatas
    for i, local_shard in enumerate(local_shards):
        local_shard_tensor = local_shard.tensor
        local_shard_meta = local_shard.metadata
        local_shard_metadatas.append(local_shard_meta)
        rank, local_device = _parse_and_validate_remote_device(pg, local_shard_meta.placement)

        if local_shard_tensor.layout != torch.strided or local_shard_tensor.layout != first_shard_layout:
            raise ValueError(
                f'Only torch.strided layout is currently supported, but found '
                f'{local_shard_tensor.layout} on rank:{current_rank}!'
            )

        if not local_shard_tensor.is_contiguous():
            raise ValueError('Only torch.contiguous_format memory_format is currently supported!')

        if rank != current_rank:
            raise ValueError(
                f"Local shard metadata's rank does not match with the rank in its process group! "
                f'Found current rank in the process group: {current_rank}, '
                f"local ShardMetadata placement's rank: {rank}"
            )
        if local_shard_tensor.device != local_device:
            raise ValueError(
                f"Local shard tensor device does not match with local Shard's placement! "
                f"Found local shard tensor device: {local_shard_tensor.device}, "
                f"local shard metadata placement device: {local_device}"
            )

        _raise_if_mismatch(local_shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank)
        _raise_if_mismatch(local_shard_tensor.is_pinned(), first_shard_is_pinned, "pin_memory", current_rank)
        _raise_if_mismatch(local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank)
        _raise_if_mismatch(local_shard_tensor.requires_grad, first_shard_requires_grad, "requires_grad", current_rank)

    # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then
    #    do all_gather to collect local_sharded_tensor_metadata from all ranks
    local_tensor_properties = TensorProperties(
        dtype=first_shard_dtype,
        layout=first_shard_layout,
        requires_grad=first_shard_requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=first_shard_is_pinned
    )

    local_sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=local_shard_metadatas,
        size=global_size,
        tensor_properties=local_tensor_properties)

    return local_sharded_tensor_metadata
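
# Illustrative sketch (hypothetical shard layout, not part of the original module):
# on rank 0 of a 2-rank group, a single local shard covering rows [0, 5) of a
# 10x5 global tensor could be described and validated as follows. The process
# group handle is an assumption for the example.
#
#   metadata = ShardMetadata(shard_offsets=[0, 0], shard_sizes=[5, 5],
#                            placement="rank:0/cpu")
#   local_shard = Shard(tensor=torch.zeros(5, 5), metadata=metadata)
#   local_meta = build_metadata_from_local_shards(
#       [local_shard], torch.Size([10, 5]), current_rank=0,
#       pg=distributed_c10d._get_default_group())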


def build_global_metadata(gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]]):
    """
    Merge the per-rank metadatas gathered from all ranks into a single global
    ShardedTensorMetadata, validating that tensor properties and the global
    size agree across ranks and that shards do not overlap.
    """
    global_sharded_tensor_metadata = None
    global_metadata_rank = 0

    for rank, rank_metadata in enumerate(gathered_metadatas):
        if rank_metadata is None:
            continue

        if global_sharded_tensor_metadata is None:
            global_sharded_tensor_metadata = copy.deepcopy(rank_metadata)
            global_metadata_rank = rank
        else:
            _raise_if_mismatch(global_sharded_tensor_metadata.size,
                               rank_metadata.size,
                               "global_size",
                               [global_metadata_rank, rank],
                               is_local=False)

            # no need to check layout and memory format here since they were already checked during local shard validation
            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.dtype,
                               rank_metadata.tensor_properties.dtype,
                               "dtype",
                               [global_metadata_rank, rank],
                               is_local=False)

            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.requires_grad,
                               rank_metadata.tensor_properties.requires_grad,
                               "requires_grad",
                               [global_metadata_rank, rank],
                               is_local=False)

            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.pin_memory,
                               rank_metadata.tensor_properties.pin_memory,
                               "pin_memory",
                               [global_metadata_rank, rank],
                               is_local=False)
            # all validations passed, extend shards_metadata
            global_sharded_tensor_metadata.shards_metadata.extend(rank_metadata.shards_metadata)

    if global_sharded_tensor_metadata is not None:
        # check that shards_metadata contains no overlapping shards
        validate_non_overlapping_shards_metadata(global_sharded_tensor_metadata.shards_metadata)

        # check that shards_metadata is compatible with the global size of the sharded tensor
        check_tensor(global_sharded_tensor_metadata.shards_metadata, global_sharded_tensor_metadata.size)
    else:
        raise ValueError("ShardedTensor has no local shards on any rank!")

    return global_sharded_tensor_metadata
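
# Illustrative sketch (not part of the original module): a typical flow gathers
# the per-rank metadata produced by ``build_metadata_from_local_shards`` with an
# all_gather_object collective, then merges it on every rank. ``local_meta`` and
# ``world_size`` below are assumptions for the example.
#
#   gathered: List[Optional[ShardedTensorMetadata]] = [None] * world_size
#   torch.distributed.all_gather_object(gathered, local_meta)
#   global_meta = build_global_metadata(gathered)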