File: distributed.fsdp.fully_shard.rst

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (85 lines) | stat: -rw-r--r-- 3,733 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
torch.distributed.fsdp.fully_shard
==================================

PyTorch FSDP2 (``fully_shard``)
-------------------------------

PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation
targeting performant eager-mode while using per-parameter sharding for improved
usability.

- If you are new to FSDP, we recommend that you start with FSDP2 due to improved
  usability.
- If you are currently using FSDP1, consider evaluating the following
  differences to see if you should switch to FSDP2:

Compared to PyTorch FSDP1 (``FullyShardedDataParallel``):

- FSDP2 uses ``DTensor``-based dim-0 per-parameter sharding for a simpler
  sharding representation compared to FSDP1's flat-parameter sharding, while
  preserving similar throughput performance. More specifically, FSDP2 chunks
  each parameter on dim-0 across the data parallel workers (using
  ``torch.chunk(dim=0)``), whereas FSDP1 flattens, concatenates, and chunks a
  group of tensors together, making reasoning about what data is present on
  each worker and resharding to different parallelisms complex. Per-parameter
  sharding provides a more intuitive user experience, relaxes constraints
  around frozen parameters, and allows for communication-free (sharded) state
  dicts, which otherwise require all-gathers in FSDP1.
- FSDP2 implements a different memory management approach to handle the
  multi-stream usages that avoids ``torch.Tensor.record_stream``. This ensures
  deterministic and expected memory usage and does not require blocking the CPU
  like in FSDP1's ``limit_all_gathers=True``.
- FSDP2 exposes APIs for manual control over prefetching and collective
  scheduling, allowing power users more customization. See the methods on
  ``FSDPModule`` below for details.
- FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly
  support full state dicts. Instead, users can reshard the sharded state dicts
  containing ``DTensor`` s to full state dicts themselves using ``DTensor``
  APIs like ``DTensor.full_tensor()`` or by using higher-level APIs like
  `PyTorch Distributed Checkpoint <https://pytorch.org/docs/stable/distributed.checkpoint.html>`_ 's
  distributed state dict APIs. Also, some other args have been removed; see
  `here <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md>`_ for
  details.

If you are onboarding FSDP for the first time or if any of the above appeals to
your use case, we recommend that you consider using FSDP2.

See `this RFC <https://github.com/pytorch/pytorch/issues/114299>`_ for details
on system design and implementation.

.. note::
  ``torch.distributed.fsdp.fully_shard`` is currently in prototype state and
  under development. The core API will likely not change, but we may make some
  API changes if necessary.

.. currentmodule:: torch.distributed.fsdp

The frontend API is ``fully_shard`` that can be called on a ``module``:

.. autofunction:: fully_shard

Calling ``fully_shard(module)`` dynamically constructs a new class that
subclasses ``type(module)`` and an FSDP class ``FSDPModule``. For example, if
we call ``fully_shard(linear)`` on a module ``linear: nn.Linear``, then FSDP
constructs a new class ``FSDPLinear`` and changes ``linear`` 's type to this.
Otherwise, ``fully_shard`` does not change the module structure and parameter
fully-qualified names. The class ``FSDPModule`` allows providing some
FSDP-specific methods on the module.

.. autoclass:: FSDPModule
    :members:
    :member-order: bysource

.. autoclass:: UnshardHandle
    :members:

.. autofunction:: register_fsdp_forward_method

.. autoclass:: MixedPrecisionPolicy
    :members:

.. autoclass:: OffloadPolicy
    :members:

.. autoclass:: CPUOffloadPolicy
    :members: