File: test_async_process_executor.py (from Debian package pytorch 2.9.1+dfsg-1~exp2)

# Owner(s): ["oncall: distributed checkpointing"]

import sys
from unittest.mock import patch

import torch
from torch import distributed as dist
from torch.distributed.checkpoint._async_process_executor import (
    _ProcessBasedAsyncCheckpointExecutor,
)
from torch.distributed.checkpoint.storage import StorageWriter
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestStorageWriter(StorageWriter):
    """Unified test storage writer with configurable behaviors."""

    def __init__(
        self,
        behavior="success",
    ):
        """
        Create a test storage writer with specified behavior.

        Args:
            behavior: "success", "fail_once"
        """
        self.behavior = behavior
        self.call_count = 0

    def _should_fail(self):
        """Determine if this call should fail based on the configured behavior."""
        self.call_count += 1

        if self.behavior == "success":
            return False
        elif self.behavior == "fail_once":
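            # Fail only on the very first call; every later attempt succeeds.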
            return self.call_count == 1

        return False

    # Implement all required StorageWriter methods directly
    def reset(self, checkpoint_id=None):
        """Reset for new checkpoint."""

    def set_up_storage_writer(self, is_coordinator):
        """Set up storage writer."""

    def prepare_local_plan(self, plan):
        """Prepare local plan."""
        return plan

    def prepare_global_plan(self, plans):
        """Prepare global plan."""
        return plans

    def write_data(self, plan, planner):
        """Write data with policy-based failure behavior."""
        from torch.futures import Future

        # Check if we should fail based on policy
        if self._should_fail():
            raise RuntimeError(
                f"TestStorageWriter: {self.behavior} policy triggered failure on call {self.call_count}"
            )

        # Return a Future that completes to simple WriteResult-like objects
        future = Future()
        result = [{"success": True, "bytes_written": 100}]
        future.set_result(result)
        return future

    def finish(self, metadata, results):
        """Finish checkpoint."""
        return None

    def storage_meta(self):
        """Return storage metadata."""
        return None

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id):
        """Validate checkpoint ID."""
        return True
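
# NOTE: TestStorageWriter is a pure in-memory stub; it never touches disk. As an
# illustrative (untested) sketch, it can be handed to any API that accepts a
# StorageWriter, e.g. a plain synchronous save:
#
#   import torch.distributed.checkpoint as dcp
#   dcp.save({"w": torch.zeros(2)}, storage_writer=TestStorageWriter("success"))
#
# With behavior="fail_once", only the first write_data() call raises, which is the
# retry scenario exercised by the test below.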


class TestAsyncProcessExecutor(DTensorTestBase):
    """Test suite for async checkpoint process executor error handling using public APIs."""

    @with_comms
    def test_checkpoint_save_failure_continues_serving(self) -> None:
        """Test that checkpoint save failure doesn't exit process, continues serving."""

        test_state_dict = {
            "model": {"weight": torch.randn(4, 4), "bias": torch.randn(4)},
            "optimizer": {"param_groups": [{"lr": 0.01}]},
            "epoch": 5,
        }

        # 1. Simulate a failure in creating PG in background process.
        with patch(
            "torch.distributed.checkpoint._async_process_executor.get_free_port",
            return_value=-1,
        ):
            with self.assertRaises(ValueError):
                proc_executor = _ProcessBasedAsyncCheckpointExecutor()
                fut = proc_executor.execute_save(
                    staging_future_or_state_dict=test_state_dict,
                )
                fut.result()

        # 2. Attempt save with failing storage writer
        with patch(
            "torch.distributed.checkpoint._async_process_executor.get_free_port",
            return_value=get_free_port(),
        ) as mock_get_free_port:
            proc_executor = _ProcessBasedAsyncCheckpointExecutor()
            fut = proc_executor.execute_save(
                staging_future_or_state_dict=test_state_dict,
                storage_writer=TestStorageWriter(behavior="fail_once"),
            )
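            # The failure is surfaced through the returned future; the background
            # checkpoint process is expected to stay alive and keep serving requests.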
            self.assertIn("fail_once policy triggered failure", str(fut.exception()))
            # Verify new process was created for this attempt
            if dist.get_rank() == 0:
                mock_get_free_port.assert_called_once()

        # 3. Second save attempt with successful storage writer - process should still be alive
        with patch(
            "torch.distributed.checkpoint._async_process_executor.get_free_port",
        ) as mock_get_free_port:
            proc_executor = _ProcessBasedAsyncCheckpointExecutor()
            fut = proc_executor.execute_save(
                staging_future_or_state_dict=test_state_dict,
                storage_writer=TestStorageWriter(behavior="success"),
            )
            result = fut.result()
            # Verify process is still alive
            mock_get_free_port.assert_not_called()
            # Verify successful save
            self.assertIsNotNone(result)


if __name__ == "__main__":
    run_tests()