# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, overload, Sequence, Tuple, TypeVar, Union
from typing_extensions import deprecated

import torch
from torch.nn.parallel._functions import Gather, Scatter


__all__ = ["scatter", "scatter_kwargs", "gather"]


@deprecated(
    "`is_namedtuple` is deprecated, please use the python checks instead",
    category=FutureWarning,
)
def is_namedtuple(obj: Any) -> bool:
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    return _is_namedtuple(obj)


def _is_namedtuple(obj: Any) -> bool:
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    return (
        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
    )
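
# For instance (illustrative, not part of the original source), the duck-typed
# check above accepts both collections.namedtuple and typing.NamedTuple instances
# while rejecting plain tuples:
#
#   >>> from collections import namedtuple
#   >>> Point = namedtuple("Point", ["x", "y"])
#   >>> _is_namedtuple(Point(1, 2)), _is_namedtuple((1, 2))
#   (True, False)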


T = TypeVar("T", dict, list, tuple)


# For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise.
@overload
def scatter(
    inputs: torch.Tensor,
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = ...,
) -> Tuple[torch.Tensor, ...]:
    ...


@overload
def scatter(
    inputs: T,
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = ...,
) -> List[T]:
    ...


def scatter(inputs, target_gpus, dim=0):
    r"""Slice tensors into approximately equal chunks and distribute them across the given GPUs.

    Duplicates references to objects that are not tensors.
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None  # type: ignore[assignment]
    return res
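
# Illustrative usage (a sketch, not part of the original module), assuming two
# visible CUDA devices: `scatter` splits a tensor along `dim` into roughly equal
# chunks, one per target device, and replicates non-tensor leaves by reference.
#
#   >>> inp = torch.arange(8.0, device="cuda:0").view(8, 1)
#   >>> chunks = scatter(inp, target_gpus=[0, 1], dim=0)
#   >>> [c.device for c in chunks]      # one 4x1 chunk per GPU
#   [device(type='cuda', index=0), device(type='cuda', index=1)]
#   >>> scatter({"x": inp, "flag": True}, target_gpus=[0, 1])[1]["flag"]
#   True                                # non-tensors are duplicated as-is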


def scatter_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = 0,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
    r"""Scatter with support for kwargs dictionary."""
    scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(scattered_inputs) < len(scattered_kwargs):
        scattered_inputs.extend(
            () for _ in range(len(scattered_kwargs) - len(scattered_inputs))
        )
    elif len(scattered_kwargs) < len(scattered_inputs):
        scattered_kwargs.extend(
            {} for _ in range(len(scattered_inputs) - len(scattered_kwargs))
        )
    return tuple(scattered_inputs), tuple(scattered_kwargs)
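
# Illustrative usage (a sketch, not from the original source): `scatter_kwargs`
# scatters positional and keyword arguments together and pads whichever side came
# up short with empty tuples/dicts, so both returned tuples have one entry per
# replica. Device ids below are assumptions for the example.
#
#   >>> x = torch.randn(4, 3, device="cuda:0")
#   >>> ins, kws = scatter_kwargs((x,), {"scale": 2.0}, target_gpus=[0, 1])
#   >>> len(ins), len(kws)
#   (2, 2)
#   >>> kws[0]
#   {'scale': 2.0}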


def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) -> Any:
    r"""Gather tensors from different GPUs on a specified device.

    This function is useful for gathering the results of a distributed computation.
    It takes a sequence of objects, one for each GPU, and returns a single object
    on the specified device.

    Args:
        outputs (Any): A sequence of objects (potentially tensors) to gather.
        target_device (Union[int, torch.device]): The device to gather the tensors to.
            Use 'cpu' for CPU to avoid a deprecation warning.
        dim (int, optional): The dimension along which to gather. Default: 0.

    Returns:
        Any: A gathered object (potentially tensor) on the specified device.
    """

    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in outputs):
                raise ValueError("All dicts must have the same number of keys")
            return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
        if _is_namedtuple(out):
            return type(out)._make(map(gather_map, zip(*outputs)))
        return type(out)(map(gather_map, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        res = gather_map(outputs)
    finally:
        gather_map = None  # type: ignore[assignment]
    return res
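
# Illustrative usage (a sketch, not part of the original file): `gather` is the
# inverse of `scatter`; it walks the per-replica outputs in lockstep and
# concatenates tensor leaves along `dim` onto `target_device`. Shapes and device
# ids here are assumptions for the example.
#
#   >>> outs = [
#   ...     {"logits": torch.randn(4, 10, device="cuda:0")},
#   ...     {"logits": torch.randn(4, 10, device="cuda:1")},
#   ... ]
#   >>> merged = gather(outs, target_device=0)
#   >>> merged["logits"].shape, merged["logits"].device
#   (torch.Size([8, 10]), device(type='cuda', index=0))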