.. currentmodule:: torch.distributed.tensor

torch.distributed.tensor
===========================

.. note::
  ``torch.distributed.tensor`` is currently in an alpha state and under
  development. We commit to backward compatibility for most of the APIs listed
  in this doc, but there might be API changes when necessary.


PyTorch DTensor (Distributed Tensor)
---------------------------------------

PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handle distributed
logic, including sharded storage, operator computation, and collective communications across devices/hosts.
``DTensor`` can be used to build different parallelism solutions and supports a sharded state_dict representation
when working with multi-dimensional sharding.

Please see examples from the PyTorch native parallelism solutions that are built on top of ``DTensor``:

* `Tensor Parallel <https://pytorch.org/docs/main/distributed.tensor.parallel.html>`__
* `FSDP2 <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md>`__

.. automodule:: torch.distributed.tensor

:class:`DTensor` follows the SPMD (single program, multiple data) programming model to empower users to
write a distributed program as if it were a **single-device program with the same convergence property**. It
provides a uniform tensor sharding layout (DTensor Layout) by specifying the :class:`DeviceMesh`
and :class:`Placement`:

- :class:`DeviceMesh` represents the device topology and the communicators of the cluster using
  an n-dimensional array.

- :class:`Placement` describes the sharding layout of the logical tensor on the :class:`DeviceMesh`.
  DTensor supports three types of placements: :class:`Shard`, :class:`Replicate` and :class:`Partial`.
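
For example, a minimal sketch of how the two concepts combine (the 2 x 2 mesh shape, the dimension names and the
tensor sizes below are illustrative assumptions, and the program is assumed to be launched on 4 ranks, e.g. via
``torchrun``):

.. code-block:: python

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate, Shard, distribute_tensor

    # A 2-D DeviceMesh over 4 ranks: 2 x 2, with illustrative dimension names.
    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

    # DTensor Layout: replicate along the "dp" mesh dimension, shard tensor
    # dim 0 along the "tp" mesh dimension.
    global_tensor = torch.randn(8, 16)
    dtensor = distribute_tensor(global_tensor, mesh, placements=[Replicate(), Shard(0)])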


DTensor Class APIs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. currentmodule:: torch.distributed.tensor

:class:`DTensor` is a ``torch.Tensor`` subclass. This means that once a :class:`DTensor` is created, it can be
used in a very similar way to ``torch.Tensor``, including running different types of PyTorch operators as if
they were running on a single device, with DTensor performing the proper distributed computation for them.

In addition to the existing ``torch.Tensor`` methods, it also offers a set of additional methods to interact with
the local ``torch.Tensor``, ``redistribute`` the DTensor Layout to a new DTensor, get the full tensor content
on all devices, etc.
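
A short sketch of these methods (the 1-D mesh over 4 ranks and the shapes below are illustrative assumptions):

.. code-block:: python

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Replicate, Shard

    mesh = init_device_mesh("cuda", (4,))

    # Wrap the rank-local shard into a DTensor that is sharded on dim 0.
    local_shard = torch.randn(4, 8)
    dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])

    # Regular torch operators work as if running on a single device.
    dtensor = dtensor + 1.0

    local = dtensor.to_local()                              # rank-local torch.Tensor
    replicated = dtensor.redistribute(mesh, [Replicate()])  # change the DTensor Layout
    full = dtensor.full_tensor()                            # full "global" tensor on every rank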

.. autoclass:: DTensor
    :members: from_local, to_local, full_tensor, redistribute, device_mesh, placements
    :member-order: groupwise


DeviceMesh as the distributed communicator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. currentmodule:: torch.distributed.device_mesh

:class:`DeviceMesh` was built from DTensor as the abstraction to describe the cluster's device topology and represent
multi-dimensional communicators (on top of ``ProcessGroup``). To see the details of how to create/use a DeviceMesh,
please refer to the `DeviceMesh recipe <https://pytorch.org/tutorials/recipes/distributed_device_mesh.html>`__.
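
As a hedged sketch, a 2-D mesh and one of its sub-mesh communicators could be created as follows (the 2 x 4 shape
assumes 8 ranks, and the dimension names are illustrative):

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh

    # 2 x 4 mesh over 8 ranks, with illustrative dimension names.
    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

    # Slice out the 1-D sub-mesh (and its communicator) for the "tp" dimension.
    tp_mesh = mesh_2d["tp"]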


DTensor Placement Types
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: torch.distributed.tensor.placement_types
.. currentmodule:: torch.distributed.tensor.placement_types

DTensor supports the following types of :class:`Placement` on each :class:`DeviceMesh` dimension:

.. autoclass:: Shard
  :members:
  :undoc-members:

.. autoclass:: Replicate
  :members:
  :undoc-members:

.. autoclass:: Partial
  :members:
  :undoc-members:

.. autoclass:: Placement
  :members:
  :undoc-members:


.. _create_dtensor:

Different ways to create a DTensor
---------------------------------------

.. currentmodule:: torch.distributed.tensor

There are three ways to construct a :class:`DTensor` (see the sketch after this list):
  * :meth:`distribute_tensor` creates a :class:`DTensor` from a logical or "global" ``torch.Tensor`` on
    each rank. This can be used to shard the leaf ``torch.Tensor`` s (i.e. model parameters/buffers
    and inputs).
  * :meth:`DTensor.from_local` creates a :class:`DTensor` from a local ``torch.Tensor`` on each rank, which can
    be used to create a :class:`DTensor` from non-leaf ``torch.Tensor`` s (i.e. intermediate activation
    tensors during forward/backward).
  * DTensor provides dedicated tensor factory functions (e.g. :meth:`empty`, :meth:`ones`, :meth:`randn`, etc.)
    to allow creating different kinds of :class:`DTensor` by directly specifying the :class:`DeviceMesh` and
    :class:`Placement`. Compared to :meth:`distribute_tensor`, this directly materializes the sharded memory
    on device, instead of performing sharding after initializing the logical Tensor memory.
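
A side-by-side sketch of the three approaches (the 1-D mesh over 4 ranks and the shapes below are illustrative
assumptions):

.. code-block:: python

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Shard, distribute_tensor, ones

    mesh = init_device_mesh("cuda", (4,))

    # 1) From the logical/"global" tensor: every rank passes the same full tensor.
    global_weight = torch.randn(16, 16)
    dt1 = distribute_tensor(global_weight, mesh, [Shard(0)])

    # 2) From the rank-local tensor, e.g. an intermediate activation.
    local_activation = torch.randn(4, 16)
    dt2 = DTensor.from_local(local_activation, mesh, [Shard(0)])

    # 3) From a DTensor factory function: only the local shard is materialized.
    dt3 = ones(16, 16, device_mesh=mesh, placements=[Shard(0)])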

Create DTensor from a logical torch.Tensor
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The SPMD (single program, multiple data) programming model in ``torch.distributed`` launches multiple processes
(i.e. via ``torchrun``) to execute the same program. This means that the model inside the program would first be
initialized on the different processes (i.e. the model might be initialized on CPU, on the meta device, or directly
on GPU if there is enough memory).

``DTensor`` offers a :meth:`distribute_tensor` API that shards the model weights or Tensors to ``DTensor`` s,
creating a DTensor from the "logical" Tensor on each process. This empowers the created
``DTensor`` s to comply with the single-device semantic, which is critical for **numerical correctness**.

.. autofunction::  distribute_tensor

Along with :meth:`distribute_tensor`, DTensor also offers a :meth:`distribute_module` API to allow easier
sharding at the :class:`nn.Module` level.
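
A minimal sketch of :meth:`distribute_module` with a simple ``partition_fn`` that shards every ``nn.Linear``
parameter on dim 0 (the module, the sharding policy and the 4-rank mesh below are illustrative assumptions):

.. code-block:: python

    import torch
    from torch import nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_module, distribute_tensor

    mesh = init_device_mesh("cuda", (4,))

    def shard_params(mod_name, mod, mesh):
        # Illustrative policy: shard every parameter of nn.Linear modules on dim 0.
        if isinstance(mod, nn.Linear):
            for name, param in mod.named_parameters(recurse=False):
                dist_param = nn.Parameter(distribute_tensor(param, mesh, [Shard(0)]))
                mod.register_parameter(name, dist_param)

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
    sharded_model = distribute_module(model, mesh, partition_fn=shard_params)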

.. autofunction::  distribute_module


DTensor Factory Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

DTensor also provides dedicated tensor factory functions to allow creating a :class:`DTensor` directly
using ``torch.Tensor``-like factory function APIs (i.e. ``torch.ones``, ``torch.empty``, etc.), by additionally
specifying the :class:`DeviceMesh` and :class:`Placement` for the :class:`DTensor` being created.
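
For example (the 1-D mesh over 4 ranks and the sizes below are illustrative assumptions):

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate, Shard, randn, zeros

    mesh = init_device_mesh("cuda", (4,))

    # Each rank only allocates its own 256 x 1024 shard of the global tensor.
    sharded = randn(1024, 1024, device_mesh=mesh, placements=[Shard(0)])

    # Each rank allocates the full 1024 x 1024 tensor.
    replicated = zeros(1024, 1024, device_mesh=mesh, placements=[Replicate()])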

.. autofunction:: zeros

.. autofunction:: ones

.. autofunction:: empty

.. autofunction:: full

.. autofunction:: rand

.. autofunction:: randn


Debugging
---------------------------------------

.. automodule:: torch.distributed.tensor.debug
.. currentmodule:: torch.distributed.tensor.debug

Logging
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
When launching the program, you can turn on additional logging using the ``TORCH_LOGS`` environment variable from
`torch._logging <https://pytorch.org/docs/main/logging.html#module-torch._logging>`__:

* ``TORCH_LOGS=+dtensor`` will display ``logging.DEBUG`` messages and all levels above it.
* ``TORCH_LOGS=dtensor`` will display ``logging.INFO`` messages and above.
* ``TORCH_LOGS=-dtensor`` will display ``logging.WARNING`` messages and above.

Debugging Tools
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To debug a program that uses DTensor, and to understand more details about the collectives that happen under the
hood, DTensor provides a :class:`CommDebugMode`.
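
A minimal usage sketch (the 4-rank mesh and the DTensor computation inside the context are illustrative
placeholders):

.. code-block:: python

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate, Shard, distribute_tensor
    from torch.distributed.tensor.debug import CommDebugMode

    mesh = init_device_mesh("cuda", (4,))
    a = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
    b = distribute_tensor(torch.randn(8, 8), mesh, [Shard(1)])

    comm_mode = CommDebugMode()
    with comm_mode:
        # Any DTensor computation can go here.
        out = (a @ b).redistribute(mesh, [Replicate()])

    # Mapping from collective ops (e.g. all_reduce, all_gather) to call counts.
    print(comm_mode.get_comm_counts())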

.. autoclass:: CommDebugMode
    :members:
    :undoc-members:

To visualize the sharding of a DTensor that has fewer than 3 dimensions, DTensor provides :meth:`visualize_sharding`.
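
A small sketch (the 4-rank mesh, the tensor and the header string below are illustrative):

.. code-block:: python

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor
    from torch.distributed.tensor.debug import visualize_sharding

    mesh = init_device_mesh("cuda", (4,))
    dtensor = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

    # Prints a table showing which rank holds which slice of the global tensor.
    visualize_sharding(dtensor, header="8x8 tensor sharded on dim 0")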

.. autofunction:: visualize_sharding


Experimental Features
---------------------------------------

``DTensor`` also provides a set of experimental features. These features are either in the prototyping stage, or
have their basic functionality done while we are collecting user feedback. Please submit an issue to PyTorch if you
have feedback on these features.

.. automodule:: torch.distributed.tensor.experimental
.. currentmodule:: torch.distributed.tensor.experimental

.. autofunction:: context_parallel
.. autofunction:: local_map
.. autofunction:: register_sharding


.. modules that are missing docs, add the doc later when necessary
.. py:module:: torch.distributed.tensor.device_mesh