from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed.fsdp._shard_utils import (
    _all_gather_dtensor,
    _create_chunk_dtensor,
    _create_chunk_sharded_tensor,
)
from torch.distributed.tensor import DeviceMesh, DTensor


class FSDPExtensions(ABC):
    """
    This defines customizable hooks to enable composability with tensor
    parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to
    set a custom :class:`FSDPExtensions` that implements the hooks.
    """
    @abstractmethod
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """E.g. converting ``DistributedTensor`` to local tensor."""
        ...

    @abstractmethod
    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        """E.g. converting local tensor to ``DistributedTensor``."""
        ...

    @abstractmethod
    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Shards a tensor into chunks and returns the local chunk."""
        ...

    @abstractmethod
    def chunk_dtensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        device_mesh: DeviceMesh,
    ) -> torch.Tensor:
        """Shards a tensor or DTensor into a DTensor and returns the local DTensor."""
        ...

    @abstractmethod
    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        """
        This is to be called before loading a *sharded* model state dict and
        should return the tensor and the list of shards from which to load data.
        """
        ...

    @abstractmethod
    def all_gather_dtensor(
        self,
        tensor: DTensor,
        parent_mesh: Optional[DeviceMesh],
    ) -> torch.Tensor:
        """
        This is to be called before loading a *sharded* DTensor state dict.
        It gathers the tensor along the FSDP dimension and returns the local
        tensor of the TP DTensor.
        """
        ...
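# Usage sketch (illustrative only): a tensor-parallelism integration would
# subclass ``FSDPExtensions`` and register one instance via
# ``_set_fsdp_extensions``. The name ``MyTPExtensions`` below is hypothetical
# and not part of this module.
#
#     class MyTPExtensions(FSDPExtensions):
#         def pre_flatten_transform(self, tensor):
#             # e.g. return (local_tensor, sharding_info) for a DTensor,
#             # or (tensor, None) for a plain tensor
#             ...
#
#         def post_unflatten_transform(self, tensor, param_extension):
#             # e.g. rewrap the local tensor using the saved sharding info
#             ...
#
#         # chunk_tensor / chunk_dtensor / pre_load_state_dict_transform /
#         # all_gather_dtensor must also be implemented.
#
#     _set_fsdp_extensions(MyTPExtensions())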
_extensions: Optional[FSDPExtensions] = None


def _set_fsdp_extensions(flattener: FSDPExtensions) -> None:
    global _extensions
    _extensions = flattener


def _ext_pre_flatten_transform(
    tensor: torch.Tensor,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, Optional[Any]]:
    if fsdp_extension is not None:
        new_tensor, param_extension = fsdp_extension.pre_flatten_transform(tensor)
        if param_extension is not None:
            return new_tensor, param_extension
    return tensor, None


def _ext_post_unflatten_transform(
    tensor: torch.Tensor,
    param_extension: Any,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    if fsdp_extension is not None and param_extension is not None:
        return fsdp_extension.post_unflatten_transform(tensor, param_extension)
    return tensor
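# Dispatch sketch (illustrative only): FSDP calls ``_ext_pre_flatten_transform``
# on each parameter before flattening and ``_ext_post_unflatten_transform`` when
# unflattening. With no extension set (or a ``None`` param_extension), both are
# identity transforms:
#
#     param = torch.empty(4, 4)
#     local, ext = _ext_pre_flatten_transform(param, fsdp_extension=None)
#     restored = _ext_post_unflatten_transform(local, ext, fsdp_extension=None)
#     assert restored is param and ext is None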
def _ext_chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    chunk_tensor_fn = (
        fsdp_extension.chunk_tensor
        if fsdp_extension is not None
        else _create_chunk_sharded_tensor
    )
    return chunk_tensor_fn(
        tensor,
        rank,
        world_size,
        num_devices_per_node,
        pg,
    )
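# Fallback sketch (illustrative only): without an extension, each rank shards a
# full tensor into a ``ShardedTensor`` holding only its local chunk, e.g. inside
# an initialized process group:
#
#     sharded = _ext_chunk_tensor(
#         tensor=torch.randn(16, 8),
#         rank=dist.get_rank(),
#         world_size=dist.get_world_size(),
#         num_devices_per_node=torch.cuda.device_count(),
#         pg=dist.distributed_c10d._get_default_group(),
#     )
#     # ``sharded`` is a ShardedTensor when no extension is set.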
def _ext_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    chunk_dtensor_fn = (
        fsdp_extension.chunk_dtensor
        if fsdp_extension is not None
        else _create_chunk_dtensor
    )
    return chunk_dtensor_fn(
        tensor,
        rank,
        device_mesh,
    )
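# Fallback sketch (illustrative only): without an extension, the tensor is
# wrapped into a ``DTensor`` sharded over the given device mesh, e.g.:
#
#     from torch.distributed.device_mesh import init_device_mesh
#     mesh = init_device_mesh("cuda", (dist.get_world_size(),))
#     dt = _ext_chunk_dtensor(torch.randn(16, 8), rank=dist.get_rank(), device_mesh=mesh)
#     # ``dt`` is a DTensor whose local shard lives on this rank.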
def _ext_pre_load_state_dict_transform(
    tensor: torch.Tensor,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, List[Shard]]:
    if fsdp_extension is not None:
        return fsdp_extension.pre_load_state_dict_transform(tensor)
    assert type(tensor) is ShardedTensor
    shards = tensor.local_shards()
    return (tensor, shards)


def _ext_all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    all_gather_dtensor_fn = (
        fsdp_extension.all_gather_dtensor
        if fsdp_extension is not None
        else _all_gather_dtensor
    )
    return all_gather_dtensor_fn(tensor, parent_mesh)
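# Load-path sketch (illustrative only): when loading a sharded state dict, FSDP
# uses ``_ext_pre_load_state_dict_transform`` to obtain the local shards to copy
# into, and ``_ext_all_gather_dtensor`` to all-gather a DTensor along the FSDP
# dimension back into the local tensor expected by tensor parallelism:
#
#     tensor, shards = _ext_pre_load_state_dict_transform(sharded_tensor)
#     full_local = _ext_all_gather_dtensor(dtensor, parent_mesh=None)
#
# ``sharded_tensor`` and ``dtensor`` above are placeholders for values produced
# when the state dict was saved.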