# mypy: allow-untyped-defs
import enum
from typing import Any, Callable, overload

import torch
from torch.distributed.algorithms.join import Joinable, JoinHook
from torch.optim import Optimizer
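
# Join hook for the Join context manager (uneven inputs): after a rank joins
# early, the main hook still runs the wrapped optimizer's step() so the joined
# rank updates its parameter shard and participates in the broadcasts that the
# non-joined ranks issue.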
class _ZeROJoinHook(JoinHook):
zero: Any = ...
def __init__(self, zero: Any) -> None: ...
def main_hook(self) -> None: ...
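
# A (possibly non-contiguous) subset of one DDP gradient bucket assigned to a
# single rank to update when overlapping ZeRO with DDP; `offset` gives the
# index of the subset's first element within the bucket, and `tensor`, once
# allocated, holds the flattened data of the assigned parameters.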
class _DDPBucketAssignment:
bucket_index: int
parameters: list[torch.Tensor]
offset: int
device: torch.device
tensor: torch.Tensor | None
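
# Lazy-initialization state machine for the DDP-overlap path: construction
# starts UNINITIALIZED, advances to DDP_HAS_REBUILT_BUCKETS once DDP has
# rebuilt its gradient buckets (fixing the bucketing that ZeRO reuses), and
# reaches INITIALIZED after ZeRO has derived its own state from those buckets.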
class _OverlapStatus(enum.IntEnum):
UNINITIALIZED: int = ...
DDP_HAS_REBUILT_BUCKETS: int = ...
INITIALIZED: int = ...
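
# Bookkeeping for overlapping step() with DDP's backward pass: per-bucket and
# per-rank parameter assignments, in-flight broadcast handles, and the futures
# used to chain ZeRO's local steps onto DDP's gradient communication.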
class _OverlapInfo:
status: Any = ...
params_per_bucket: Any = ...
params_per_rank: Any = ...
offsets: Any = ...
broadcast_handles: Any = ...
bucket_index_to_future: Any = ...
bucket_index_to_bucket: Any = ...
bucket_indices_seen: Any = ...
assigned_ranks_per_bucket: list[set[int]] = ...
total_size: int = ...
shard_buckets: bool = ...
def __init__(self) -> None: ...
def wait_for_broadcasts(self) -> None: ...
def clear_per_iter_info(self) -> None: ...
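
# Wraps an arbitrary `optimizer_class` and shards its optimizer state across
# the ranks of `process_group` in the style of ZeRO stage 1: in step(), each
# rank updates only the shard of parameters assigned to it and then broadcasts
# its updated parameters to all other ranks.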
class ZeroRedundancyOptimizer(Optimizer, Joinable):
functional_optim_map: Any = ...
initialized: bool = ...
process_group: Any = ...
world_size: int = ...
rank: int = ...
global_rank: int = ...
parameters_as_bucket_view: bool = ...
optim: Any = ...
_device_to_device_index: dict[torch.device, int] = ...
_overlap_with_ddp: bool = ...
_overlap_info: _OverlapInfo = ...
_buckets: list[list[torch.Tensor]] = ...
_bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ...
def __init__(
self,
params: Any,
optimizer_class: type[Optimizer],
process_group: Any | None = ...,
parameters_as_bucket_view: bool = ...,
overlap_with_ddp: bool = ...,
**defaults: Any,
) -> None: ...
def add_param_group(self, param_group: dict[str, Any]) -> None: ...
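    # Gathers the sharded optimizer state onto rank `to`; call this on all
    # ranks before state_dict(), which otherwise has only this rank's shard.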
def consolidate_state_dict(self, to: int = ...) -> None: ...
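    # Mirrors Optimizer.step(): performs a local step on this rank's shard,
    # then syncs the updated parameters across ranks; returns the closure's
    # loss when a closure is supplied.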
@overload
def step(self, closure: None = ..., **kwargs: Any) -> None: ...
@overload
def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ...
def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...
def state_dict(self) -> dict[str, Any]: ...
def _local_step(
self,
gradients: list[torch.Tensor | None] | None = None,
closure: Callable[[], float] | None = None,
**kwargs: Any,
) -> float | None: ...
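    # Maps a DDP gradient bucket index to the rank assigned to update it when
    # overlapping with DDP.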
def _get_assigned_rank(self, bucket_index: int) -> int: ...
def _init_zero_for_overlap(self) -> None: ...
    def join_hook(self, **kwargs: Any) -> JoinHook: ...
    @property
    def join_device(self) -> torch.device: ...
    @property
    def join_process_group(self) -> Any: ...
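
# Typical usage (a minimal sketch; assumes torch.distributed has already been
# initialized and that `rank` and a batch of `inputs` come from the training
# script):
#
#     import torch
#     import torch.distributed as dist
#     import torch.nn as nn
#     from torch.distributed.optim import ZeroRedundancyOptimizer
#     from torch.nn.parallel import DistributedDataParallel as DDP
#
#     model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
#     ddp = DDP(model, device_ids=[rank])
#     opt = ZeroRedundancyOptimizer(
#         ddp.parameters(),
#         optimizer_class=torch.optim.Adam,
#         lr=0.01,
#     )
#     ddp(inputs).sum().backward()
#     opt.step()
#
#     # Checkpointing: consolidate the sharded state on rank 0 before saving.
#     opt.consolidate_state_dict(to=0)
#     if dist.get_rank() == 0:
#         torch.save(opt.state_dict(), "zero_opt.pt")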