# mypy: allow-untyped-defs
from typing import List, Optional, Sequence, Tuple

import numpy as np

from torch._prims_common import ShapeType
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.placement_types import Placement, Shard

__all__ = ["visualize_sharding"]

def _mesh_to_coordinate(mesh, device_type):
    """
    Given an n-dimensional device mesh, create a map from each device to its
    coordinate in the mesh.
    """
    # Convert the n-dimensional mesh to a NumPy array
    np_mesh = np.array(mesh.mesh.tolist())

    # Map each device to its coordinate
    device_to_coordinate_map = {}
    for coord, value in np.ndenumerate(np_mesh):
        # device is unique in device_mesh
        device_to_coordinate_map[f"{device_type}:{str(value)}"] = list(coord)

    return device_to_coordinate_map
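
# Illustration (hypothetical example, not executed): for a 2x2 mesh whose ranks are
# [[0, 1], [2, 3]] and device_type "cuda", _mesh_to_coordinate returns
# {"cuda:0": [0, 0], "cuda:1": [0, 1], "cuda:2": [1, 0], "cuda:3": [1, 1]}.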

def _convert_offset_to_ranges(all_offsets):
    """
    Convert per-device (shape, offset, device) entries into row and column ranges.

    Building the table with the tabulate package is easier when each block is
    described by explicit row and column ranges, so this function performs that
    conversion.
    """
    converted_blocks = []

    for shape, offset, value in all_offsets:
        # Calculate row_range and column_range from the local shape and global offset
        row_range = (offset[0], offset[0] + shape[0] - 1)
        column_range = (offset[1], offset[1] + shape[1] - 1)

        # Convert the device name to a string for the table cell
        converted_block = {
            "row_range": row_range,
            "column_range": column_range,
            "value": str(value),
        }
        converted_blocks.append(converted_block)

    return converted_blocks
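
# Illustration (hypothetical example, not executed): the entry
# [(4, 8), (4, 0), "cuda:1"], i.e. (local_shape, global_offset, device), becomes
# {"row_range": (4, 7), "column_range": (0, 7), "value": "cuda:1"}.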

def _create_table(blocks):
    """
    Create a tabulate table from row and column ranges labeled with device names.
    """
    try:
        from tabulate import tabulate
    except ImportError as e:
        raise ImportError("tabulate package is required to visualize sharding") from e

    # Extract unique row and column ranges
    row_ranges = sorted({block["row_range"] for block in blocks})
    col_ranges = sorted({block["column_range"] for block in blocks})

    # Create a matrix initialized with empty strings
    matrix = [["" for _ in col_ranges] for _ in row_ranges]

    # Fill the matrix with values; devices holding the same block share a cell
    for block in blocks:
        row_index = row_ranges.index(block["row_range"])
        col_index = col_ranges.index(block["column_range"])
        if matrix[row_index][col_index] == "":
            matrix[row_index][col_index] = block["value"]
        else:
            matrix[row_index][col_index] += ", " + block["value"]

    # Prepare headers
    row_headers = [f"Row {r[0]}-{r[1]}" for r in row_ranges]
    col_headers = [f"Col {c[0]}-{c[1]}" for c in col_ranges]
    return tabulate(matrix, headers=col_headers, showindex=row_headers)
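
# Illustration (hypothetical example, not executed): for an 8x8 tensor split across
# two devices along dim 0, _create_table produces output roughly like
#
#              Col 0-7
#   ---------  ---------
#   Row 0-3    cuda:0
#   Row 4-7    cuda:1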

def _compute_local_shape_and_global_offset(
    global_shape: ShapeType,
    mesh: DeviceMesh,
    placements: Sequence[Placement],
    my_coordinate: Optional[List[int]],
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """
    Same as torch.distributed._tensor._utils.compute_local_shape_and_global_offset
    but with a custom ``my_coordinate`` input. This is the modified implementation
    used by visualize_sharding.
    """
    if my_coordinate is None:
        # if rank is not in the mesh, return empty offset
        return ((), ())
    else:
        local_shape = list(global_shape)
        global_offset = [0] * len(global_shape)

        for idx, placement in enumerate(placements):
            mesh_dim_size = mesh.size(idx)
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                local_offset = [0] * len(global_shape)
                assert shard_dim < len(
                    local_shape
                ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
                shard_size, shard_offset = placement._local_shard_size_on_dim(
                    local_shape[shard_dim],
                    mesh_dim_size,
                    my_coordinate[idx],
                    return_offset=True,
                )
                local_shape[shard_dim] = shard_size
                local_offset[shard_dim] = shard_offset

                # On a given dimension, if local_offset[shard_dim] is smaller than
                # global_offset[shard_dim], this dimension has already been sharded
                # by a previous placement. In that case we cannot simply replace
                # global_offset[shard_dim] with local_offset[shard_dim]; instead we
                # add local_offset[shard_dim] to the existing global_offset[shard_dim].
                if global_offset[shard_dim] <= local_offset[shard_dim]:
                    global_offset[shard_dim] = local_offset[shard_dim]
                else:
                    global_offset[shard_dim] += local_offset[shard_dim]

        return tuple(local_shape), tuple(global_offset)
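
# Illustration (hypothetical examples, not executed):
# - global_shape (8, 8), 1-D mesh of size 4, placements [Shard(0)], coordinate [2]
#   -> local_shape (2, 8), global_offset (4, 0).
# - global_shape (8, 8), 2x2 mesh, placements [Shard(0), Shard(0)], coordinate [1, 1]
#   -> local_shape (2, 8), global_offset (6, 0): the second placement's offset (2)
#   is added to the first placement's offset (4) because dim 0 was already sharded.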

def visualize_sharding(dtensor, header=""):
    """
    Visualizes sharding in the terminal for :class:`DTensor` that are 1D or 2D.

    .. note:: This requires the ``tabulate`` package. No sharding info will be
        printed for empty tensors.
    """
    if dtensor.numel() == 0:  # we do not print for empty dtensors
        return

    if len(dtensor.shape) >= 3:
        raise RuntimeError(
            "visualize sharding is only implemented for 1D or 2D dtensor"
        )

    placements = dtensor.placements
    device_mesh = dtensor.device_mesh
    device_type = dtensor.device_mesh.device_type

    if device_mesh.get_coordinate() is None:  # current rank is not in the mesh
        return

    # Only display the visualization once for each DTensor, on the rank whose
    # coordinate is 0 on all mesh dimensions. For example, on a full mesh we
    # only print on global rank 0.
    local_rank_zero_on_all_dim = all(
        device_mesh.get_local_rank(mesh_dim=dim) == 0 for dim in range(device_mesh.ndim)
    )
    if not local_rank_zero_on_all_dim:
        return

    device_map = _mesh_to_coordinate(device_mesh, device_type)
    all_offsets = []
    for device in device_map:
        local_shape, global_offset = _compute_local_shape_and_global_offset(
            dtensor.shape, device_mesh, placements, device_map[device]
        )
        all_offsets.append([local_shape, global_offset, device])

    # Convert offsets to blocks with row ranges for tabulate
    blocks = _convert_offset_to_ranges(all_offsets)

    # Print the table
    print(header)
    print(_create_table(blocks))
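
# Example usage (illustrative sketch, not executed as part of this module; assumes a
# 4-rank job launched with torchrun and the public DTensor APIs exported by
# torch.distributed.tensor):
#
#     import torch
#     from torch.distributed.tensor import Shard, distribute_tensor, init_device_mesh
#
#     mesh = init_device_mesh("cuda", (4,))
#     dtensor = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
#     visualize_sharding(dtensor, header="8x8 DTensor sharded on dim 0")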