1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
|
# Owner(s): ["oncall: distributed checkpointing"]
import sys
from unittest.mock import patch
import torch
from torch import distributed as dist
from torch.distributed.checkpoint._async_process_executor import (
_ProcessBasedAsyncCheckpointExecutor,
)
from torch.distributed.checkpoint.storage import StorageWriter
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
if TEST_WITH_DEV_DBG_ASAN:
    # Bail out before any test class is defined: torch combined with
    # multiprocessing spawn has known issues under dev-asan builds.
    # Exit status 0 (success) so the skipped run is not reported as an error.
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
class TestStorageWriter(StorageWriter):
    """In-memory ``StorageWriter`` stub whose ``write_data`` can be told to fail.

    Nothing is persisted. Every method is a no-op or echoes its input, except
    ``write_data``. That method either raises ``RuntimeError`` or returns an
    already-completed ``torch.futures.Future``.
    """

    # Despite the ``Test`` prefix this is a helper, not a test case. Setting
    # ``__test__`` stops pytest from trying (and warning about failing) to
    # collect it, since it defines ``__init__``.
    __test__ = False

    # Behaviors understood by ``_should_fail``.
    _BEHAVIORS = ("success", "fail_once")

    def __init__(
        self,
        behavior="success",
    ):
        """
        Create a test storage writer with specified behavior.

        Args:
            behavior: "success" (``write_data`` never fails) or "fail_once"
                (``write_data`` raises on its first call only).

        Raises:
            ValueError: if ``behavior`` is not one of the supported values,
                so a typo cannot silently degrade into "success".
        """
        if behavior not in self._BEHAVIORS:
            raise ValueError(
                f"Unknown TestStorageWriter behavior {behavior!r}; "
                f"expected one of {self._BEHAVIORS}"
            )
        self.behavior = behavior
        # Number of write_data() calls observed so far (bumped in _should_fail).
        self.call_count = 0

    def _should_fail(self):
        """Count this call and report whether it should fail under ``behavior``."""
        self.call_count += 1
        # "fail_once" fails only on the very first call; anything else never fails.
        return self.behavior == "fail_once" and self.call_count == 1

    # Implement all required StorageWriter methods directly

    def reset(self, checkpoint_id=None):
        """No-op reset for a new checkpoint.

        ``call_count`` is not reset here, so "fail_once" fails at most once
        per writer instance.
        """

    def set_up_storage_writer(self, is_coordinator, *args, **kwargs):
        """No-op setup.

        Extra positional/keyword arguments are accepted and ignored so the
        stub keeps working if the checkpoint driver passes additional setup
        options (presumably e.g. a rank in newer torch versions; confirm).
        """

    def prepare_local_plan(self, plan):
        """Return the local plan unchanged."""
        return plan

    def prepare_global_plan(self, plans):
        """Return the global plans unchanged."""
        return plans

    def write_data(self, plan, planner):
        """Fail according to ``behavior``; otherwise return a completed future.

        Returns:
            A ``torch.futures.Future`` already resolved to a list holding one
            WriteResult-like dict. No data is actually written.

        Raises:
            RuntimeError: when the configured behavior triggers a failure.
        """
        from torch.futures import Future

        # Check if we should fail based on policy
        if self._should_fail():
            raise RuntimeError(
                f"TestStorageWriter: {self.behavior} policy triggered failure on call {self.call_count}"
            )
        # Return a Future that completes to simple WriteResult-like objects
        future = Future()
        future.set_result([{"success": True, "bytes_written": 100}])
        return future

    def finish(self, metadata, results):
        """Finish checkpoint (no-op)."""
        return None

    def storage_meta(self):
        """Return no storage metadata."""
        return None

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id):
        """Accept every checkpoint ID."""
        return True
class TestAsyncProcessExecutor(DTensorTestBase):
    """Error-handling tests for the process-based async checkpoint executor.

    Failures are injected by patching the executor module's ``get_free_port``
    and by saving through ``TestStorageWriter``.
    """

    @with_comms
    def test_checkpoint_save_failure_continues_serving(self) -> None:
        """A failed checkpoint save must not exit the background process.

        Phases:
          1. ``get_free_port`` yields -1, so process-group creation in the
             background process fails and a ``ValueError`` surfaces.
          2. A writer that fails once makes the save future carry that
             error; on rank 0 the port is looked up exactly once.
          3. A following successful save completes without any port lookup,
             i.e. the background process is still alive and reused.
        """
        port_lookup = (
            "torch.distributed.checkpoint._async_process_executor.get_free_port"
        )
        state_dict = {
            "model": {"weight": torch.randn(4, 4), "bias": torch.randn(4)},
            "optimizer": {"param_groups": [{"lr": 0.01}]},
            "epoch": 5,
        }

        # Phase 1: a bogus port breaks PG creation in the background process.
        # The patch is entered first and exited last, as with nested ``with``s.
        with patch(port_lookup, return_value=-1), self.assertRaises(ValueError):
            executor = _ProcessBasedAsyncCheckpointExecutor()
            pending = executor.execute_save(
                staging_future_or_state_dict=state_dict,
            )
            pending.result()

        # Phase 2: the storage writer raises on its first write.
        with patch(port_lookup, return_value=get_free_port()) as port_mock:
            executor = _ProcessBasedAsyncCheckpointExecutor()
            pending = executor.execute_save(
                staging_future_or_state_dict=state_dict,
                storage_writer=TestStorageWriter(behavior="fail_once"),
            )
            failure_text = str(pending.exception())
            self.assertIn("fail_once policy triggered failure", failure_text)
            # A new process was created for this attempt; the port lookup
            # is only asserted on rank 0.
            if dist.get_rank() == 0:
                port_mock.assert_called_once()

        # Phase 3: a healthy writer succeeds on the surviving process.
        with patch(port_lookup) as port_mock:
            executor = _ProcessBasedAsyncCheckpointExecutor()
            pending = executor.execute_save(
                staging_future_or_state_dict=state_dict,
                storage_writer=TestStorageWriter(behavior="success"),
            )
            outcome = pending.result()
            # No port lookup: the existing background process is still alive.
            port_mock.assert_not_called()
            # The save itself produced a result.
            self.assertIsNotNone(outcome)
if __name__ == "__main__":
    # Executed directly as a script: hand off to PyTorch's shared test runner.
    run_tests()
|