1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334
|
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
import os
import warnings
from concurrent.futures import Future, ThreadPoolExecutor
from typing import cast, Optional, Union
from typing_extensions import deprecated
import torch
import torch.distributed as dist
from torch.distributed._state_dict_utils import _offload_state_dict_to_cpu
from torch.distributed.checkpoint._storage_utils import _storage_setup
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.logger import _dcp_method_logger
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
from torch.distributed.checkpoint.staging import AsyncStager
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.storage import StorageWriter
from torch.distributed.distributed_c10d import _get_default_group
from .utils import _api_bc_check, _DistWrapper, _profile
# Public entry points of this module (the names exported by a star-import).
__all__ = ["save_state_dict", "save", "async_save"]
@deprecated(
    "`save_state_dict` is deprecated and will be removed in future versions. "
    "Please use `save` instead.",
    category=FutureWarning,
)
def save_state_dict(
    state_dict: STATE_DICT_TYPE,
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[SavePlanner] = None,
) -> Metadata:
    """This method is deprecated. Please switch to 'save'.

    Resets ``storage_writer`` and then delegates to ``_save_state_dict``
    inside the ``_profile()`` context. Unlike ``save``, ``state_dict`` is
    forwarded as-is: ``Stateful`` values are not expanded here.

    Args:
        state_dict (Dict[str, Any]): The state_dict to save.
        storage_writer (StorageWriter):
            Writer used to perform the writes. ``reset()`` is called on it
            before the save starts.
        process_group (Optional[ProcessGroup]):
            Forwarded to ``_save_state_dict``. (Default: ``None``)
        coordinator_rank (int):
            Forwarded to ``_save_state_dict``. (Default: ``0``)
        no_dist (bool):
            Forwarded to ``_save_state_dict``. (Default: ``False``)
        planner (Optional[SavePlanner]):
            Forwarded to ``_save_state_dict``. (Default: ``None``)

    Returns:
        Metadata: The value returned by ``_save_state_dict``.
    """
    storage_writer.reset()

    # TODO: test returning `save` here instead.
    with _profile():
        return _save_state_dict(
            state_dict,
            storage_writer,
            process_group,
            coordinator_rank,
            no_dist,
            planner,
        )
@_dcp_method_logger(log_exceptions=True)  # type: ignore[arg-type]
@_api_bc_check
def save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Metadata:
    """
    Save a distributed model in SPMD style.

    This function is different from ``torch.save()`` as it handles
    ``ShardedTensor`` and ``DTensor`` by having each rank only save its local shards.

    For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
    save will call ``state_dict`` before serialization.

    .. warning::
        There are no guarantees of backwards compatibility across PyTorch versions
        for saved state_dicts.

    .. warning::
        If using the `process_group` argument, make sure that only its ranks
        call `save` and that all data in state_dict belong to it.

    .. note::
        When saving a checkpoint for FSDP's `ShardingStrategy.HYBRID_SHARD`, only one of
        the shard_groups should be calling `save` and the corresponding process
        group needs to be passed in.

    .. note::
        If no process group is available, this function assumes the intention is to save the
        state_dict in the local process (a warning is emitted in that case).

    .. note::
        Rank 0 is assumed to be the coordinator rank.

    Args:
        state_dict (Dict[str, Any]): The state_dict to save.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_writer (Optional[StorageWriter]):
            Instance of StorageWriter used to perform writes. If this is not
            specified, DCP will automatically infer the writer based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        planner (Optional[SavePlanner]):
            Instance of SavePlanner. If this is not specified, the default
            planner will be used. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        Metadata: Metadata object for the saved checkpoint.

    Example:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()

        >>> state_dict = {"model": my_model}

        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
        >>> torch.distributed.checkpoint.save(
        >>>     state_dict=state_dict,
        >>>     storage_writer=fs_storage_writer,
        >>> )

    .. note::
        save uses collectives to coordinate writes across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes place.
        In this case, the device used is given by ``torch.cuda.current_device()``
        and it is the user's responsibility to ensure that this is set so that
        each rank has an individual GPU, via ``torch.cuda.set_device()``.
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.save")

    # Without an initialized torch.distributed we fall back to a local, single-process save.
    distributed_ready = dist.is_available() and dist.is_initialized()
    if not distributed_ready:
        warnings.warn(
            "torch.distributed is unavailable or uninitialized, assuming the intent is to save in a single process."
        )

    with _profile():
        # Resolve the writer first (it may be inferred from checkpoint_id),
        # then expand any Stateful entries into plain state dicts.
        resolved_writer = cast(
            StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
        )
        expanded_state_dict = _stateful_to_state_dict(state_dict)

        return _save_state_dict(
            state_dict=expanded_state_dict,
            storage_writer=resolved_writer,
            process_group=process_group,
            no_dist=not distributed_ready,
            planner=planner,
        )
@_dcp_method_logger(log_exceptions=True)
def async_save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Future:
    """Asynchronous version of ``save``. The state_dict is first staged onto the
    staging storage (CPU memory by default), and ``save`` is then run in a separate thread.

    .. warning::
        This feature is experimental and subject to change.

    Args:
        state_dict (Dict[str, Any]): The state_dict to save.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_writer (Optional[StorageWriter]):
            Instance of StorageWriter used to perform 'stage' and 'save'. If
            this is not specified, DCP will automatically infer the writer based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        planner (Optional[SavePlanner]):
            Instance of SavePlanner. If this is not specified, the default
            planner will be used. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        Future: A future holding the resultant Metadata object from `save`.

    Raises:
        AssertionError: If torch.distributed is initialized and the process group
            in use has no CPU device type.

    Example:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()

        >>> state_dict = {"model": my_model}

        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
        >>> checkpoint_future = torch.distributed.checkpoint.async_save(
        >>>     state_dict=state_dict,
        >>>     storage_writer=fs_storage_writer,
        >>> )
        >>>
        >>> # ... do some work ...
        >>>
        >>> checkpoint_future.result()

    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.async_save")

    if dist.is_available() and dist.is_initialized():
        active_group = process_group or _get_default_group()
        assert (
            torch.device("cpu") in active_group._device_types  # type: ignore[attr-defined]
        ), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"

    writer = cast(
        StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
    )
    expanded_state_dict = _stateful_to_state_dict(state_dict)

    # Writers that implement AsyncStager stage the data themselves; all others
    # fall back to a CPU offload (keeps older storage writers working).
    staged_state_dict = (
        writer.stage(expanded_state_dict)
        if isinstance(writer, AsyncStager)
        else _offload_state_dict_to_cpu(expanded_state_dict, type_check=False)
    )

    executor = ThreadPoolExecutor(max_workers=1)

    def _release_executor(_finished: Future) -> None:
        # Release the single worker thread once the save has completed.
        executor.shutdown(wait=False)

    save_future: Future = executor.submit(
        save,
        staged_state_dict,
        checkpoint_id=checkpoint_id,
        storage_writer=writer,
        planner=planner,
        process_group=process_group,
    )
    save_future.add_done_callback(_release_executor)

    # Some stagers ask to be synchronized only after the save has been kicked off.
    if isinstance(writer, AsyncStager):
        if writer.should_synchronize_after_execute:
            writer.synchronize_staging()

    return save_future
def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
"""Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object."""
stateful_state_dict = {}
for key, elem in state_dict.items():
stateful_state_dict[key] = (
elem.state_dict() if isinstance(elem, Stateful) else elem
)
return stateful_state_dict
def _save_state_dict(
    state_dict: STATE_DICT_TYPE,
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[SavePlanner] = None,
) -> Metadata:
    """Plan, write and finalize a checkpoint through ``storage_writer``.

    The work is split into four logged steps that are handed to a
    ``_DistWrapper``:

    * ``local_step`` / ``global_step`` are passed to ``reduce_scatter("plan", ...)``;
      the former builds this rank's plan, the latter merges all plans and
      records the global metadata.
    * ``write_data`` / ``finish_checkpoint`` are passed to ``all_reduce("write", ...)``;
      the former writes the final local plan, the latter calls
      ``storage_writer.finish`` with the recorded metadata.

    Args:
        state_dict: The state_dict handed to the planner.
        storage_writer: Writer that prepares plans, writes data and finishes
            the checkpoint.
        process_group: Group given to ``_DistWrapper``. (Default: ``None``)
        coordinator_rank: Coordinator rank given to ``_DistWrapper``. (Default: ``0``)
        no_dist: When true, ``_DistWrapper`` is created with distributed use
            disabled. (Default: ``False``)
        planner: Planner to use; a ``DefaultSavePlanner`` is created when
            ``None``. (Default: ``None``)

    Returns:
        Metadata: The result of the ``all_reduce("write", ...)`` call.
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")

    dist_wrapper = _DistWrapper(process_group, not no_dist, coordinator_rank)
    # Bind the planner once to a non-optional name so the closures below can use it directly.
    active_planner: SavePlanner = DefaultSavePlanner() if planner is None else planner
    global_metadata: Optional[Metadata] = None

    # Keyword arguments shared by every logged step; the checkpoint id is only
    # included when the writer exposes one.
    writer_checkpoint_id = getattr(storage_writer, "checkpoint_id", None)
    logger_kwargs = (
        {} if writer_checkpoint_id is None else {"checkpoint_id": writer_checkpoint_id}
    )
    logger_kwargs["process_group"] = dist_wrapper.group

    @_dcp_method_logger(**logger_kwargs)
    def local_step():
        storage_meta = storage_writer.storage_meta()
        planner_params = inspect.signature(active_planner.set_up_planner).parameters
        if "storage_meta" in planner_params:
            active_planner.set_up_planner(
                state_dict=state_dict,
                storage_meta=storage_meta,
                is_coordinator=dist_wrapper.is_coordinator,
            )
        else:
            # Older planner implementations predate the storage_meta parameter:
            # warn and fall back to the two-argument positional call.
            warnings.warn(
                "The function definition for SavePlanner.set_up_planner has been updated"
                " to include the storage_meta argument. Please update your implementation"
                " to include this parameter."
            )
            active_planner.set_up_planner(state_dict, dist_wrapper.is_coordinator)  # type: ignore[call-arg, arg-type]
        storage_writer.set_up_storage_writer(dist_wrapper.is_coordinator)

        draft_plan = active_planner.create_local_plan()
        return storage_writer.prepare_local_plan(draft_plan)

    @_dcp_method_logger(**logger_kwargs)
    def global_step(all_local_plans):
        nonlocal global_metadata

        merged_plans, global_metadata = active_planner.create_global_plan(
            all_local_plans
        )
        return storage_writer.prepare_global_plan(merged_plans)

    central_plan: SavePlan = dist_wrapper.reduce_scatter(
        "plan", local_step, global_step
    )

    @_dcp_method_logger(**logger_kwargs)
    def write_data():
        final_local_plan = active_planner.finish_plan(central_plan)
        pending_writes = storage_writer.write_data(final_local_plan, active_planner)

        # Block until the writer reports completion, then hand back its results.
        pending_writes.wait()
        return pending_writes.value()

    @_dcp_method_logger(**logger_kwargs)
    def finish_checkpoint(all_results):
        # The metadata is only set by global_step; refuse to finish without it.
        assert global_metadata is not None
        storage_writer.finish(metadata=global_metadata, results=all_results)
        return global_metadata

    return dist_wrapper.all_reduce("write", write_data, finish_checkpoint)
|