# mypy: allow-untyped-defs
from dataclasses import dataclass
from typing import Optional
import torch


@dataclass(frozen=True)
class MixedPrecisionPolicy:
    """
    This configures FSDP's mixed precision. Unlike autocast, this applies mixed
    precision at the module level, not op level, which means low-precision
    activations are saved for backward and high-to-low-precision casts are
    incurred only at module boundaries.

    FSDP works well with module-level mixed precision since it keeps the
    high-precision sharded parameters in memory anyway. In other words, FSDP
    does not require any extra memory to keep a high-precision copy of the
    parameters for the optimizer step.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for
            the unsharded parameter and hence the dtype for forward/backward
            computation and the parameter all-gather. If this is ``None``, then
            the unsharded parameter uses the original dtype. The optimizer step
            uses the sharded parameter in the original dtype. (Default:
            ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then the reduction
            uses the compute dtype. This can be used to run gradient reduction
            in full precision while using low precision for compute. If also
            gradient reduction is disabled via :meth:`set_requires_gradient_sync`,
            then FSDP will accumulate gradients using ``reduce_dtype``.
            (Default: ``None``)
        output_dtype (Optional[torch.dtype]): This specifies the dtype for
            casting floating-point forward outputs. This can be used to
            help implement cases where different modules have different mixed
            precision policies. (Default: ``None``)
        cast_forward_inputs (bool): This specifies whether FSDP should cast the
            forward's floating-point input tensors to ``param_dtype`` or not.
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    cast_forward_inputs: bool = True

    def __post_init__(self):
        # Clamp `reduce_dtype` to `None` if no casting is required: since
        # gradients are computed in `param_dtype`, if `reduce_dtype` matches,
        # then we do not need extra casting
        if self.param_dtype == self.reduce_dtype:
            # Bypass the frozen dataclass checks
            object.__setattr__(self, "reduce_dtype", None)
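

# Illustrative sketch (not part of this module): a common bf16-compute,
# fp32-reduce policy. With the FSDP2 ``fully_shard`` API, such a policy is
# typically passed via the ``mp_policy`` keyword argument; the helper name and
# the dtype choices below are assumptions for illustration, not defaults.
def _example_mixed_precision_policy() -> MixedPrecisionPolicy:
    return MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,  # all-gather and forward/backward compute in bf16
        reduce_dtype=torch.float32,  # reduce-scatter gradients in fp32
    )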


@dataclass
class OffloadPolicy:
    """
    This base class represents the policy of no offloading and is only used as
    the default value for the ``offload_policy`` arg.
    """


@dataclass
class CPUOffloadPolicy(OffloadPolicy):
    """
    This offload policy offloads parameters, gradients, and optimizer states to
    CPU. Sharded parameters are copied host-to-device before all-gather. The
    all-gathered parameters are freed according to ``reshard_after_forward``.
    Sharded gradients are copied device-to-host in backward, and the optimizer
    step runs on CPU with CPU optimizer states.

    Attributes:
        pin_memory (bool): Whether to pin sharded parameter and gradient
            memory. Pinning memory allows both more efficient H2D/D2H copies
            and for the copies to overlap with compute. However, the pinned
            memory cannot be used by other processes. Set this to ``False`` if
            you have insufficient CPU memory. (Default: ``True``)
    """

    pin_memory: bool = True
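

# Illustrative sketch (not part of this module): a CPU offload policy with
# pinned memory disabled, for hosts with limited CPU RAM. With the FSDP2
# ``fully_shard`` API, such a policy is typically passed via the
# ``offload_policy`` keyword argument; the helper name is an assumption for
# illustration.
def _example_cpu_offload_policy() -> CPUOffloadPolicy:
    # Trades slower, non-overlapped H2D/D2H copies for less pinned host memory.
    return CPUOffloadPolicy(pin_memory=False)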