torch.distributed.fsdp.fully_shard
==================================

PyTorch FSDP2 (``fully_shard``)
-------------------------------

PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation
that targets performant eager-mode execution while using per-parameter sharding
for improved usability.

- If you are new to FSDP, we recommend that you start with FSDP2 due to its
  improved usability.
- If you are currently using FSDP1, consider evaluating the following
  differences to see if you should switch to FSDP2:

Compared to PyTorch FSDP1 (``FullyShardedDataParallel``):

- FSDP2 uses ``DTensor``-based dim-0 per-parameter sharding for a simpler
  sharding representation compared to FSDP1's flat-parameter sharding, while
  preserving similar throughput performance. More specifically, FSDP2 chunks
  each parameter on dim-0 across the data parallel workers (using
  ``torch.chunk(dim=0)``), whereas FSDP1 flattens, concatenates, and chunks a
  group of tensors together, making it complex to reason about what data is
  present on each worker and to reshard to different parallelisms.
  Per-parameter sharding provides a more intuitive user experience, relaxes
  constraints around frozen parameters, and allows for communication-free
  (sharded) state dicts, which otherwise require all-gathers in FSDP1.
- FSDP2 implements a different memory management approach for handling
  multi-stream usage that avoids ``torch.Tensor.record_stream``. This ensures
  deterministic and expected memory usage and does not require blocking the
  CPU as FSDP1's ``limit_all_gathers=True`` does.
- FSDP2 exposes APIs for manual control over prefetching and collective
  scheduling, allowing power users more customization. See the methods on
  ``FSDPModule`` below for details.
- FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly
  support full state dicts. Instead, users can reshard the sharded state dicts
  containing ``DTensor`` s to full state dicts themselves using ``DTensor``
  APIs like ``DTensor.full_tensor()`` or by using higher-level APIs like
  `PyTorch Distributed Checkpoint <https://pytorch.org/docs/stable/distributed.checkpoint.html>`_ 's
  distributed state dict APIs (see the sketch after this list). Also, some
  other arguments have been removed; see
  `here <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md>`_ for
  details.
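
As a rough sketch of the per-parameter sharding and communication-free sharded
state dicts described above (the module below is only illustrative, and the
default process group is assumed to already be initialized, e.g. via
``torchrun``):

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard
    from torch.distributed.tensor import DTensor

    # Illustrative module; each rank runs this after process group init.
    model = nn.Linear(16, 16, bias=False)
    fully_shard(model)

    # Each parameter is now a DTensor sharded on dim-0 across the data
    # parallel ranks (torch.chunk(dim=0) semantics).
    assert isinstance(model.weight, DTensor)
    print(model.weight.placements)        # e.g. (Shard(dim=0),)
    print(model.weight.to_local().shape)  # this rank's dim-0 chunk

    # The sharded state dict contains DTensors and needs no communication.
    sharded_sd = model.state_dict()

    # Resharding to a full tensor is an explicit, user-driven all-gather.
    full_weight = sharded_sd["weight"].full_tensor()
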
If you are onboarding FSDP for the first time, or if any of the above applies
to your use case, we recommend that you consider using FSDP2.

See `this RFC <https://github.com/pytorch/pytorch/issues/114299>`_ for details
on system design and implementation.

.. note::
   ``torch.distributed.fsdp.fully_shard`` is currently in prototype state and
   under development. The core API will likely not change, but we may make some
   API changes if necessary.

.. currentmodule:: torch.distributed.fsdp

The frontend API is ``fully_shard``, which can be called on a ``module``:

.. autofunction:: fully_shard

Calling ``fully_shard(module)`` dynamically constructs a new class that
subclasses ``type(module)`` and an FSDP class ``FSDPModule``. For example, if
we call ``fully_shard(linear)`` on a module ``linear: nn.Linear``, then FSDP
constructs a new class ``FSDPLinear`` and changes the type of ``linear`` to it.
Otherwise, ``fully_shard`` does not change the module structure or parameter
fully-qualified names. The class ``FSDPModule`` provides some FSDP-specific
methods on the module.
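
Below is a minimal usage sketch (the model, sizes, and training loop are
illustrative only; it assumes one CUDA device per rank and that the default
process group has been initialized, e.g. via ``torchrun``):

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FSDPModule, fully_shard

    model = nn.Transformer(d_model=256, nhead=8, num_encoder_layers=2,
                           num_decoder_layers=2)

    # Apply fully_shard bottom-up: each call groups that module's parameters
    # into one unit that is all-gathered and resharded together.
    for layer in list(model.encoder.layers) + list(model.decoder.layers):
        fully_shard(layer)
    fully_shard(model)  # the root picks up the remaining parameters

    # fully_shard swapped the class in place: `model` is now both an
    # nn.Transformer and an FSDPModule (e.g. type name "FSDPTransformer").
    assert isinstance(model, nn.Transformer)
    assert isinstance(model, FSDPModule)

    # Training looks the same as for a local module.
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    src = torch.randn(10, 2, 256, device="cuda")
    tgt = torch.randn(10, 2, 256, device="cuda")
    loss = model(src, tgt).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
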
.. autoclass:: FSDPModule
    :members:
    :member-order: bysource
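
For example (a sketch only, assuming a stack of blocks that have each been
wrapped with ``fully_shard`` as above), the manual-control methods can be used
to customize prefetching and to unshard a module explicitly:

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard

    # Illustrative stack of blocks; assumes the process group is initialized.
    model = nn.Sequential(*(nn.Linear(64, 64) for _ in range(4)))
    layers = list(model)
    for layer in layers:
        fully_shard(layer)
    fully_shard(model)

    # Explicitly prefetch the next two layers' all-gathers while layer i runs
    # forward, rather than relying on the default prefetching.
    for i, layer in enumerate(layers):
        layer.set_modules_to_forward_prefetch(layers[i + 1 : i + 3])

    # Explicitly unshard (all-gather) a module ahead of time and reshard it
    # (free the unsharded parameters) when done with it.
    handle = layers[0].unshard(async_op=True)  # returns an UnshardHandle
    handle.wait()
    layers[0].reshard()
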
.. autoclass:: UnshardHandle
    :members:

.. autofunction:: register_fsdp_forward_method

.. autoclass:: MixedPrecisionPolicy
    :members:

.. autoclass:: OffloadPolicy
    :members:

.. autoclass:: CPUOffloadPolicy
    :members:
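
For reference, a short sketch of passing these policies to ``fully_shard``
(the dtypes and model are illustrative; CPU offload is optional and shown only
to illustrate the argument):

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import (
        CPUOffloadPolicy,
        MixedPrecisionPolicy,
        fully_shard,
    )

    model = nn.Linear(1024, 1024)

    # Run forward/backward compute and all-gathers in bf16 while reducing
    # gradients in fp32.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )

    # Optionally keep the sharded parameters (and hence optimizer states) on
    # CPU, all-gathering onto the GPU only for compute.
    fully_shard(model, mp_policy=mp_policy, offload_policy=CPUOffloadPolicy())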