File: local_timer_example.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (124 lines) | stat: -rw-r--r-- 4,161 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import multiprocessing as mp
import signal
import time

import torch.distributed.elastic.timer as timer
import torch.multiprocessing as torch_mp
from torch.testing._internal.common_utils import (
    IS_MACOS,
    IS_WINDOWS,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)


# Route INFO-and-above records to the root logger, tagging each line with
# its level, timestamp and originating module.
logging.basicConfig(
    format="[%(levelname)s] %(asctime)s %(module)s: %(message)s",
    level=logging.INFO,
)


def _happy_function(rank, mp_queue):
    """Worker whose guarded block finishes before its timer deadline.

    Installs a ``LocalTimerClient`` bound to ``mp_queue`` and then sleeps for
    0.5 seconds inside a 1 second ``timer.expires`` block.

    Args:
        rank: index of this worker; not used by the body, but accepted so the
            function matches how the callers in this file invoke it
            (``torch_mp.spawn`` and ``Process(args=(i, mp_queue))``).
        mp_queue: queue read by the ``LocalTimerServer`` in the parent process.
    """
    client = timer.LocalTimerClient(mp_queue)
    timer.configure(client)
    # 0.5s of "work" against a 1s deadline, so this is expected to exit cleanly.
    with timer.expires(after=1):
        time.sleep(0.5)


def _stuck_function(rank, mp_queue):
    """Worker whose guarded block overruns its timer deadline.

    Installs a ``LocalTimerClient`` bound to ``mp_queue`` and then sleeps for
    5 seconds inside a 1 second ``timer.expires`` block. The tests in this
    file expect the timer server to kill the process (exit code -SIGKILL)
    before the sleep returns.

    Args:
        rank: index of this worker; not used by the body, but accepted so the
            function matches how the callers in this file invoke it
            (``torch_mp.spawn`` and ``Process(args=(i, mp_queue))``).
        mp_queue: queue read by the ``LocalTimerServer`` in the parent process.
    """
    client = timer.LocalTimerClient(mp_queue)
    timer.configure(client)
    # Sleep far past the 1s deadline so the timer expires mid-block.
    with timer.expires(after=1):
        time.sleep(5)


# timer is not supported on macos or windows
if not (IS_WINDOWS or IS_MACOS):

    class LocalTimerExample(TestCase):
        """
        Demonstrates how to use LocalTimerServer and LocalTimerClient
        to enforce expiration of code-blocks.

        The server (in the test process) and the clients (in the worker
        processes) communicate over a multiprocessing queue. That queue has to
        be created from the same multiprocessing context (start method) that
        launches the workers.

        ``torch.multiprocessing.spawn`` does not take a multiprocessing context
        as an argument, so ``test_torch_mp_example`` creates its queue from the
        "spawn" context. That matches the start method
        ``torch.multiprocessing.spawn`` uses by default.

        NOTE(review): an earlier version of this docstring stated that
        ``test_torch_mp_example`` is disabled and would SIGSEGV. However, the
        test is only skipped under ASAN; confirm whether that caveat still
        applies.
        """

        # Number of worker processes launched by each test.
        _WORLD_SIZE = 8
        # ``max_interval`` passed to LocalTimerServer. Kept tiny so the tests
        # run fast; in practice set it to a larger value (e.g. 60 seconds).
        _MAX_INTERVAL = 0.01

        def _start_server(self, mp_queue):
            # Create and start a LocalTimerServer that reads from ``mp_queue``.
            server = timer.LocalTimerServer(
                mp_queue, max_interval=self._MAX_INTERVAL
            )
            server.start()
            return server

        @skip_but_pass_in_sandcastle_if(
            TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible"
        )
        def test_torch_mp_example(self):
            mp_queue = mp.get_context("spawn").Queue()
            server = self._start_server(mp_queue)
            try:
                # Every worker finishes before its deadline, so spawn returns.
                torch_mp.spawn(
                    fn=_happy_function,
                    args=(mp_queue,),
                    nprocs=self._WORLD_SIZE,
                    join=True,
                )

                with self.assertRaises(Exception):
                    # The expired workers get killed; torch.multiprocessing.spawn
                    # kills all sub-procs if one of them gets killed and raises.
                    torch_mp.spawn(
                        fn=_stuck_function,
                        args=(mp_queue,),
                        nprocs=self._WORLD_SIZE,
                        join=True,
                    )
            finally:
                # Stop the server even if a spawn call or assertion fails.
                server.stop()

        @skip_but_pass_in_sandcastle_if(
            TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible"
        )
        def test_example_start_method_spawn(self):
            self._run_example_with(start_method="spawn")

        # NOTE: the forkserver variant below is currently disabled.
        # @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible")
        # def test_example_start_method_forkserver(self):
        #     self._run_example_with(start_method="forkserver")

        def _run_example_with(self, start_method):
            """Run a mix of stuck and happy workers using ``start_method``.

            Even ranks run ``_stuck_function`` and are expected to exit with
            ``-SIGKILL``. Odd ranks run ``_happy_function`` and are expected to
            exit with 0.
            """
            spawn_ctx = mp.get_context(start_method)
            # The queue comes from the same context that starts the workers.
            mp_queue = spawn_ctx.Queue()
            server = self._start_server(mp_queue)

            processes = []
            try:
                for rank in range(self._WORLD_SIZE):
                    target = _stuck_function if rank % 2 == 0 else _happy_function
                    p = spawn_ctx.Process(target=target, args=(rank, mp_queue))
                    p.start()
                    processes.append(p)

                # Join every worker before asserting, so that one failed check
                # cannot leave the remaining workers unjoined.
                for p in processes:
                    p.join()
            finally:
                server.stop()

            for rank, p in enumerate(processes):
                expected = -signal.SIGKILL if rank % 2 == 0 else 0
                self.assertEqual(expected, p.exitcode, msg=f"rank {rank}")


if __name__ == "__main__":
    # Executed directly (not imported): hand control to the shared test runner.
    run_tests()