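"""Public API for ``torch.distributed.fsdp``.

This package re-exports both generations of the FSDP API from their
implementation modules: FSDP1, centered on the ``FullyShardedDataParallel``
wrapper class, and FSDP2, centered on the ``fully_shard`` function.

A minimal illustrative sketch (``model`` here stands for any
``torch.nn.Module``; it assumes a default process group has already been
initialized, e.g. via ``torch.distributed.init_process_group``)::

    from torch.distributed.fsdp import FullyShardedDataParallel, fully_shard

    # FSDP1: wrap the module in a FullyShardedDataParallel instance.
    wrapped = FullyShardedDataParallel(model)

    # FSDP2: shard the module's parameters in place.
    fully_shard(model)
"""
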
from ._flat_param import FlatParameter as FlatParameter
from ._fully_shard import (
CPUOffloadPolicy,
FSDPModule,
fully_shard,
MixedPrecisionPolicy,
OffloadPolicy,
register_fsdp_forward_method,
UnshardHandle,
)
from .fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel,
LocalOptimStateDictConfig,
LocalStateDictConfig,
MixedPrecision,
OptimStateDictConfig,
OptimStateKeyType,
ShardedOptimStateDictConfig,
ShardedStateDictConfig,
ShardingStrategy,
StateDictConfig,
StateDictSettings,
StateDictType,
)
__all__ = [
# FSDP1
"BackwardPrefetch",
"CPUOffload",
"FullOptimStateDictConfig",
"FullStateDictConfig",
"FullyShardedDataParallel",
"LocalOptimStateDictConfig",
"LocalStateDictConfig",
"MixedPrecision",
"OptimStateDictConfig",
"OptimStateKeyType",
"ShardedOptimStateDictConfig",
"ShardedStateDictConfig",
"ShardingStrategy",
"StateDictConfig",
"StateDictSettings",
"StateDictType",
# FSDP2
"CPUOffloadPolicy",
"FSDPModule",
"fully_shard",
"MixedPrecisionPolicy",
"OffloadPolicy",
"register_fsdp_forward_method",
"UnshardHandle",
]
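
# Illustrative sketch (comments only, not executed here): after
# ``fully_shard(model)``, the module is also an ``FSDPModule``, which exposes
# manual unsharding; ``model`` is a placeholder for any sharded
# ``torch.nn.Module``.
#
#   fully_shard(model)
#   assert isinstance(model, FSDPModule)
#   handle = model.unshard(async_op=True)  # returns an UnshardHandle
#   handle.wait()  # block until the all-gathered parameters are ready
#   model.reshard()  # free the unsharded parameters again
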
# Reassign __module__ so these names, defined in private modules, report the
# public ``torch.distributed.fsdp`` namespace
CPUOffloadPolicy.__module__ = "torch.distributed.fsdp"
FSDPModule.__module__ = "torch.distributed.fsdp"
fully_shard.__module__ = "torch.distributed.fsdp"
MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp"
OffloadPolicy.__module__ = "torch.distributed.fsdp"
register_fsdp_forward_method.__module__ = "torch.distributed.fsdp"
UnshardHandle.__module__ = "torch.distributed.fsdp"
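
# Illustrative sketch (comments only, not executed here):
# ``register_fsdp_forward_method`` registers a method other than ``forward``
# (``"generate"`` below is a placeholder name) so that FSDP unshards
# parameters before the method runs and reshards them afterward.
#
#   register_fsdp_forward_method(model, "generate")
#   output = model.generate(...)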