File: _visualize_sharding.py

# mypy: allow-untyped-defs
from typing import List, Optional, Sequence, Tuple

import numpy as np

from torch._prims_common import ShapeType
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.placement_types import Placement, Shard


__all__ = ["visualize_sharding"]


def _mesh_to_coordinate(mesh, device_type):
    """
    Given an n-dimensional device mesh, create a map from each device to its
    coordinate in the mesh.
    """
    # Convert the n-dimensional list to a NumPy array
    np_mesh = np.array(mesh.mesh.tolist())

    # Create a dictionary to map each value to its coordinate
    device_to_coordinate_map = {}
    for coord, value in np.ndenumerate(np_mesh):
        # each device appears exactly once in the device mesh
        device_to_coordinate_map[f"{device_type}:{value}"] = list(coord)

    return device_to_coordinate_map

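# Illustrative sketch, assuming a hypothetical 2x2 "cuda" mesh built on an
# already-initialized 4-rank process group (kept as comments so importing this
# module stays side-effect free):
#
#   mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])
#   _mesh_to_coordinate(mesh, "cuda")
#   # -> {"cuda:0": [0, 0], "cuda:1": [0, 1],
#   #     "cuda:2": [1, 0], "cuda:3": [1, 1]}
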

def _convert_offset_to_ranges(all_offsets):
    """
    Convert (shape, offset, device) triples into inclusive row and column ranges.
    Building a table with the tabulate package is easier when row and column
    ranges are given explicitly.
    """
    converted_blocks = []

    for entry in all_offsets:
        shape, offset, value = entry

        # Calculate row_range and column_range
        row_range = (offset[0], offset[0] + shape[0] - 1)
        column_range = (offset[1], offset[1] + shape[1] - 1)

        # Store the device name as a string
        converted_block = {
            "row_range": row_range,
            "column_range": column_range,
            "value": str(value),
        }
        converted_blocks.append(converted_block)

    return converted_blocks

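# Illustrative sketch on plain data (kept as comments so importing this module
# stays side-effect free): a 2x4 shard placed at offset (2, 4) on a
# hypothetical device "cuda:3" becomes a pair of inclusive ranges.
#
#   _convert_offset_to_ranges([[(2, 4), (2, 4), "cuda:3"]])
#   # -> [{"row_range": (2, 3), "column_range": (4, 7), "value": "cuda:3"}]
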

def _create_table(blocks):
    """
    Create a tabulate table that places each device name in the cell given by
    its row and column ranges.
    """
    try:
        from tabulate import tabulate
    except ImportError as e:
        raise ImportError("tabulate package is required to visualize sharding") from e

    # Extract unique row and column ranges
    row_ranges = sorted({block["row_range"] for block in blocks})
    col_ranges = sorted({block["column_range"] for block in blocks})

    # Create a matrix initialized with empty strings
    matrix = [["" for _ in col_ranges] for _ in row_ranges]

    # Fill the matrix with values
    for block in blocks:
        row_index = row_ranges.index(block["row_range"])
        col_index = col_ranges.index(block["column_range"])
        if matrix[row_index][col_index] == "":
            matrix[row_index][col_index] = block["value"]
        else:
            matrix[row_index][col_index] += ", " + block["value"]

    # Prepare headers
    row_headers = [f"Row {r[0]}-{r[1]}" for r in row_ranges]
    col_headers = [f"Col {c[0]}-{c[1]}" for c in col_ranges]

    return tabulate(matrix, headers=col_headers, showindex=row_headers)

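# Illustrative sketch, assuming the tabulate package is installed: four blocks
# of a 4x8 tensor sharded across hypothetical devices cuda:0..cuda:3 render
# roughly as:
#
#            Col 0-3    Col 4-7
#   -------  ---------  ---------
#   Row 0-1  cuda:0     cuda:1
#   Row 2-3  cuda:2     cuda:3
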

def _compute_local_shape_and_global_offset(
    global_shape: ShapeType,
    mesh: DeviceMesh,
    placements: Sequence[Placement],
    my_coordinate: Optional[List[int]],
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """
    Same as torch.distributed._tensor._utils.compute_local_shape_and_global_offset,
    but with my_coordinate as an explicit input so that visualize_sharding can
    compute the local shard of every device in the mesh, not just the current rank.
    """

    if my_coordinate is None:
        # if the rank is not in the mesh, return an empty shape and offset
        return ((), ())
    else:
        local_shape = list(global_shape)
        global_offset = [0] * len(global_shape)

        for idx, placement in enumerate(placements):
            mesh_dim_size = mesh.size(idx)
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                local_offset = [0] * len(global_shape)
                assert shard_dim < len(
                    local_shape
                ), f"Sharding dim {shard_dim} is out of range for tensor ndim {len(local_shape)}"
                shard_size, shard_offset = placement._local_shard_size_on_dim(
                    local_shape[shard_dim],
                    mesh_dim_size,
                    my_coordinate[idx],
                    return_offset=True,
                )

                local_shape[shard_dim] = shard_size
                local_offset[shard_dim] = shard_offset

                # On a given dimension, if local_offset[shard_dim] is smaller than
                # global_offset[shard_dim], this dimension has already been sharded by a
                # previous placement. In that case we cannot simply replace
                # global_offset[shard_dim] with local_offset[shard_dim]; we must instead
                # add local_offset[shard_dim] to the existing global_offset[shard_dim].
                if global_offset[shard_dim] <= local_offset[shard_dim]:
                    global_offset[shard_dim] = local_offset[shard_dim]
                else:
                    global_offset[shard_dim] += local_offset[shard_dim]
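                # Worked example (illustrative): with global_shape = (8,), a 2x2
                # mesh, placements = [Shard(0), Shard(0)], and my_coordinate = [1, 1]:
                #   idx 0: 8 elements over 2 ranks -> shard_size 4, shard_offset 4;
                #          0 <= 4, so global_offset[0] is replaced with 4
                #   idx 1: remaining 4 over 2 ranks -> shard_size 2, shard_offset 2;
                #          4 > 2, so global_offset[0] = 4 + 2 = 6
                # i.e. this coordinate holds local shape (2,) at global offset (6,).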

        return tuple(local_shape), tuple(global_offset)


def visualize_sharding(dtensor, header=""):
    """
    Visualize the sharding of a 1D or 2D :class:`DTensor` in the terminal.

    .. note:: This requires the ``tabulate`` package. No sharding info is printed for empty tensors.
    """
    if dtensor.numel() == 0:  # we do not print for empty dtensors
        return

    if len(dtensor.shape) >= 3:
        raise RuntimeError(
            "visualize sharding is only implemented for 1D or 2D dtensor"
        )
    placements = dtensor.placements
    device_mesh = dtensor.device_mesh
    device_type = dtensor.device_mesh.device_type

    if device_mesh.get_coordinate() is None:  # current rank is not in the mesh
        return

    # Only display the visualization once for each DTensor, on the rank whose
    # coordinate is 0 on all dimensions. For example, if the mesh is a full mesh,
    # we will only print on rank 0.
    local_rank_zero_on_all_dim = all(
        device_mesh.get_local_rank(mesh_dim=dim) == 0 for dim in range(device_mesh.ndim)
    )
    if not local_rank_zero_on_all_dim:
        return

    device_map = _mesh_to_coordinate(device_mesh, device_type)
    all_offsets = []
    for device in device_map:
        local_shape, global_offset = _compute_local_shape_and_global_offset(
            dtensor.shape, device_mesh, placements, device_map[device]
        )
        all_offsets.append([local_shape, global_offset, device])

    # Convert offsets to blocks with row_ranges for tabulate
    blocks = _convert_offset_to_ranges(all_offsets)

    # Print the table
    print(header)
    print(_create_table(blocks))
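
# Example usage (an illustrative sketch, assuming four ranks launched with
# torchrun, CUDA available, and the default process group initialized; kept as
# comments so importing this module stays side-effect free):
#
#   import torch
#   from torch.distributed.tensor import distribute_tensor, init_device_mesh
#
#   mesh = init_device_mesh("cuda", (2, 2))
#   dtensor = distribute_tensor(torch.randn(4, 8), mesh, [Shard(0), Shard(1)])
#   # prints the table only on the rank whose coordinate is 0 on all mesh dims
#   visualize_sharding(dtensor, header="4x8 tensor on a 2x2 mesh")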