File: __init__.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (44 lines) | stat: -rw-r--r-- 968 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
NOTICE: DTensor has moved to torch.distributed.tensor

This file is a shim that redirects to the new location; we
keep the old import paths starting with `_tensor` for
backward compatibility. We will remove this folder once
we resolve all the BC issues.
"""
import sys
from importlib import import_module


submodules = [
    # TODO: _shards_wrapper/_utils here mainly for checkpoint BC, remove them
    "_shards_wrapper",
    "_utils",
    "experimental",
    "device_mesh",
]

# Redirect imports
for submodule in submodules:
    full_module_name = f"torch.distributed.tensor.{submodule}"
    sys.modules[f"torch.distributed._tensor.{submodule}"] = import_module(
        full_module_name
    )

# Re-export the public DTensor API at this legacy module path so that
# ``from torch.distributed._tensor import DTensor`` (and the other names
# below) keeps working after the move to ``torch.distributed.tensor``.
# The noqa silences "imported but unused" — being imported here IS the point.
from torch.distributed.tensor import (  # noqa: F401
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
    empty,
    full,
    init_device_mesh,
    ones,
    Partial,
    Placement,
    rand,
    randn,
    Replicate,
    Shard,
    zeros,
)