File: _utils.py

Package: pytorch-cuda 2.6.0+dfsg-7 (area: contrib; suites: forky, sid, trixie)
# mypy: allow-untyped-defs
import warnings
from typing import Tuple, Union

from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.placement_types import Placement


# Fall back to a no-op stub when the dynamo helper cannot be imported, so that
# this module still imports without ``torch._dynamo``.
try:
    from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling
except Exception:

    def is_torchdynamo_compiling():  # type: ignore[misc]
        return False


LayoutsType = Union[Placement, Tuple[Placement, ...]]
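# Illustrative examples (not part of the original module): a layout is either a
# single placement for a 1D mesh or one placement per mesh dimension:
#
#     from torch.distributed.tensor.placement_types import Replicate, Shard
#
#     colwise: LayoutsType = Shard(0)                # 1D mesh
#     hybrid: LayoutsType = (Shard(0), Replicate())  # 2D mesh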


def _deprecate_warnings(func_name: str, extra_msg: str) -> None:
    """
    Inject common validation logics for `_prepare_input` funcs via this decorator.

    Include verifying that input needs to be either a :class:`Tensor` or :class:`DTensor`
    and only 1D :class:`DeviceMesh` is passed in.
    """
    # TODO: Will follow up with dynamo POC to make warnings.warn working with dynamo.
    if not is_torchdynamo_compiling():
        warnings.warn(
            f"{func_name} is deprecated and will be removed soon. {extra_msg}",
            FutureWarning,
            stacklevel=3,
        )
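

# A minimal usage sketch (illustrative only; ``_old_tp_api`` is a hypothetical
# caller, not part of this module): a deprecated entry point calls the helper
# once before delegating to its replacement.
def _old_tp_api() -> None:
    _deprecate_warnings(
        "_old_tp_api",
        "Use torch.distributed.tensor.parallel.parallelize_module instead.",
    )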


def _validate_tp_mesh_dim(
    device_mesh: DeviceMesh,
) -> None:
    """
    Check whether TP mesh dimension is valid or not.

    Args:
        device_mesh (:class:`DeviceMesh`):
            The `device_mesh` where we perform
            Tensor Parallelism on.

    Return:
        `True` if the mesh dimension
        is valid, `False` otherwise.
    """
    if device_mesh.ndim > 1:
        raise ValueError(
            f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!"
            'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]'
        )

    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
    # If the root mesh is not the same as device_mesh, then device_mesh was
    # sliced out from the root mesh.
    if root_mesh and root_mesh != device_mesh:
        tp_mesh_dim_in_root = _mesh_resources.get_root_mesh_dim(device_mesh)
        if tp_mesh_dim_in_root != root_mesh.ndim - 1:
            raise RuntimeError(
                f"Found TP device_mesh on the {tp_mesh_dim_in_root} dimension of its parent mesh. "
                "Currently we only support intranode TP, and TP needs to be the innermost dimension of its parent mesh."
            )
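

# --- Usage sketch (illustrative only; not part of the original module) ---
# Assumes a distributed run (e.g. via torchrun) with 8 GPUs; the mesh shape and
# dimension names below are arbitrary.
if __name__ == "__main__":
    from torch.distributed.device_mesh import init_device_mesh

    # A 2D (dp, tp) mesh: only its innermost "tp" slice is a valid TP mesh.
    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
    _validate_tp_mesh_dim(mesh_2d["tp"])  # passes: 1D and innermost in root
    try:
        _validate_tp_mesh_dim(mesh_2d)  # raises: mesh is 2D
    except ValueError as err:
        print(err)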