import functools
from typing import Optional
import torch
import torch.distributed as dist
class DefaultState(object):
r"""
Stores state needed to perform the default communication algorithm
within a communication hook.
Args:
process_group (ProcessGroup): The process group to be used.
"""
__slots__ = [
"process_group",
"world_size",
"gradient_predivide_factor",
"gradient_postdivide_factor"
]
def __init__(
self,
process_group: dist.ProcessGroup
):
if process_group is None:
raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.")
self.process_group = process_group
self.world_size = dist.get_world_size(process_group)
# Set two division factors, `self.gradient_predivide_factor` and
# `self.gradient_postdivide_factor`, whose product equals `world_size`,
# so that gradients are averaged in two steps to avoid underflow and overflow.
self.gradient_predivide_factor = self._get_gradient_predivide_factor(
self.world_size
)
self.gradient_postdivide_factor = self.world_size / self.gradient_predivide_factor
def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
factor *= 2
return float(factor)
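
# Illustrative helper (hypothetical, added for exposition): a minimal sketch that
# prints the (pre, post) division factors ``DefaultState`` derives for a few sample
# world sizes. The two divisions compose to an average over ``world_size``
# (pre * post == world_size) while keeping each individual division smaller,
# which limits underflow/overflow when gradients live in low precision.
def _example_divide_factors() -> None:
    for world_size in (1, 2, 8, 64, 96):
        # ``self`` is unused inside ``_get_gradient_predivide_factor``, so calling
        # it unbound with ``None`` is safe for this demonstration.
        pre = DefaultState._get_gradient_predivide_factor(None, world_size)
        post = world_size / pre
        print(f"world_size={world_size}: predivide={pre}, postdivide={post}")
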
class LowPrecisionState(DefaultState):
r"""
Stores state needed to perform gradient communication in a lower precision
within a communication hook. The communication hook will cast gradients back
to the original parameter precision, specified by ``parameter_type`` (default: ``torch.float32``).
Builds on top of the :class:`DefaultState`.
Args:
parameter_type (torch.dtype): The precision of the model's parameters.
Required for a hook to cast gradients back to a parameter's precision.
"""
__slots__ = [
"parameter_type",
]
def __init__(
self,
process_group,
parameter_type=torch.float32,
):
super().__init__(process_group)
self.parameter_type = parameter_type
def _decompress(state: LowPrecisionState, grad: torch.Tensor):
"""
Casts gradients back to full parameter precision so that
further computation happens in full precision.
"""
orig_grad_data = grad.data
grad.data = grad.data.to(state.parameter_type)
# Don't let this memory get reused until after the transfer.
orig_grad_data.record_stream(torch.cuda.current_stream()) # type: ignore[arg-type]
def allreduce_hook(state: DefaultState, grad: torch.Tensor):
r"""
This FSDP communication hook implements the ``all_reduce`` algorithm
and the necessary pre- and post-division of gradients.
Args:
state (DefaultState): State information, configures pre- and post-division factors.
grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks.
"""
# Average grad by pre-division factor. Together pre- and post-division factors
# lead to an overall averaging by world_size, required for consistency with PyTorch DDP.
# This is a two-step process to avoid potential underflow and overflow.
if state.gradient_predivide_factor > 1:
grad.div_(state.gradient_predivide_factor)
dist.all_reduce(grad, group=state.process_group)
# Average grad by post-division factor.
if state.gradient_postdivide_factor > 1:
grad.div_(state.gradient_postdivide_factor)
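
# Minimal usage sketch (hypothetical, assuming the default process group has already
# been initialized elsewhere, e.g. via ``dist.init_process_group("gloo")``, so that
# ``all_reduce`` can run on CPU tensors). The hook is a plain function, so it can be
# exercised directly on a gradient tensor:
def _example_allreduce_hook() -> None:
    state = DefaultState(process_group=dist.group.WORLD)
    grad = torch.ones(4) * (dist.get_rank() + 1)
    allreduce_hook(state, grad)
    # ``grad`` now holds the gradient averaged across ranks, because the pre- and
    # post-division factors multiply out to ``world_size``.
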
def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor):
r"""
This FSDP communication hook implements the ``reduce_scatter`` algorithm for
sharded FSDP strategies and the necessary pre- and post-division of gradients.
Args:
state (DefaultState): State information, configures pre- and post-division factors.
grad (torch.Tensor): An unsharded gradient for the local batch that needs to be
communicated across ranks.
output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
"""
# Average grad by pre-division factor.
if state.gradient_predivide_factor > 1:
grad.div_(state.gradient_predivide_factor)
dist._reduce_scatter_base(
output, grad, group=state.process_group
)
# Average grad's shard by post-division factor.
if state.gradient_postdivide_factor > 1:
output.div_(state.gradient_postdivide_factor)
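
# Minimal usage sketch (hypothetical, assuming an initialized NCCL process group,
# CUDA tensors placed on this rank's device, and a full gradient whose length is a
# multiple of the world size, as ``reduce_scatter`` requires). ``grad`` holds the
# unsharded gradient and ``output`` receives this rank's reduced shard:
def _example_reduce_scatter_hook(shard_numel: int = 4) -> None:
    state = DefaultState(process_group=dist.group.WORLD)
    grad = torch.ones(shard_numel * state.world_size, device="cuda")
    output = torch.zeros(shard_numel, device="cuda")
    reduce_scatter_hook(state, grad, output)
    # ``output`` now holds one shard of the gradient, averaged across ranks.
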
# Shared implementation behind ``fp16_compress_hook`` and ``bf16_compress_hook``:
# casts ``grad`` (and ``output``, if given) to ``prec``, dispatches to
# ``reduce_scatter_hook`` for sharded strategies or ``allreduce_hook`` otherwise,
# and then casts the communicated result back to the parameter precision in ``state``.
def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor]):
grad.data = grad.data.to(prec)
if output is not None:
output.data = output.data.to(prec)
reduce_scatter_hook(state, grad, output)
_decompress(state, output)
else:
allreduce_hook(state, grad)
_decompress(state, grad)
def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None):
r"""
This FSDP communication hook implements a simple gradient compression
approach that casts ``grad`` to half-precision floating-point format (``torch.float16``).
It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a
``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
gradients are averaged by a ``state.gradient_postdivide_factor``.
Once post-division is done, compressed gradients are cast back to the parameters' precision.
Args:
state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
"""
fp16_hook = functools.partial(_low_precision_hook, torch.float16)
return fp16_hook(state, grad, output)
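
# Minimal usage sketch (hypothetical, assuming an initialized NCCL process group and
# CUDA tensors, since ``_decompress`` records the current CUDA stream). The gradient
# is communicated in ``torch.float16`` and returned in the precision stored in
# ``LowPrecisionState.parameter_type``:
def _example_fp16_compress_hook() -> None:
    state = LowPrecisionState(process_group=dist.group.WORLD)
    grad = torch.randn(8, device="cuda")
    fp16_compress_hook(state, grad)  # no ``output`` shard, so the all_reduce path runs
    assert grad.dtype == state.parameter_type  # cast back to torch.float32
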
def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None):
r"""
This FSDP communication hook implements a simple gradient compression
approach that casts ``grad`` to Brain floating point format (``torch.bfloat16``).
It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a
``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
gradients are averaged by a ``state.gradient_postdivide_factor``.
Once post-division is done, compressed gradients are cast back to the parameters' precision.
Args:
state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
"""
bf16_hook = functools.partial(_low_precision_hook, torch.bfloat16)
return bf16_hook(state, grad, output)
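
# Minimal usage sketch (hypothetical, assuming an initialized NCCL process group,
# CUDA tensors, and backend/hardware support for ``torch.bfloat16``). This variant
# exercises the sharded ``reduce_scatter`` path by passing an ``output`` shard:
def _example_bf16_compress_hook(shard_numel: int = 4) -> None:
    state = LowPrecisionState(process_group=dist.group.WORLD)
    grad = torch.randn(shard_numel * state.world_size, device="cuda")
    output = torch.zeros(shard_numel, device="cuda")
    bf16_compress_hook(state, grad, output)
    assert output.dtype == state.parameter_type  # shard cast back to torch.float32
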