# Owner(s): ["oncall: r2p"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import unittest.mock as mock
from torch.distributed.elastic.timer import TimerServer
from torch.distributed.elastic.timer.api import RequestQueue, TimerRequest
class MockRequestQueue(RequestQueue):
    """Request queue stub that always reports and serves the same two requests."""

    def size(self):
        """Report a fixed queue depth of two."""
        return 2

    def get(self, size, timeout):
        """Return two freshly built timer requests.

        ``size`` and ``timeout`` are accepted but ignored; every call yields
        one request for worker 1 and one for worker 2, both with an
        expiration time of 0.
        """
        canned = ((1, "test_1"), (2, "test_2"))
        return [TimerRequest(worker, scope, 0) for worker, scope in canned]
class MockTimerServer(TimerServer):
    """
    Canned ``TimerServer`` used to drive the watchdog in tests.

    Behavior of this mock:

    1. reaping worker 1 raises ``RuntimeError``
    2. reaping worker 2 returns ``True``
    3. reaping worker 3 returns ``False``

    Any other worker id makes ``_reap_worker`` return ``None``.
    ``get_expired_timers`` reports two expired timers for each of the
    workers 1 - 3, whatever deadline it is given.
    """

    def __init__(self, request_queue, max_interval):
        super().__init__(request_queue, max_interval)

    def register_timers(self, timer_requests):
        """Do nothing; the test replaces this method with a mock."""

    def clear_timers(self, worker_ids):
        """Do nothing; the test replaces this method with a mock."""

    def get_expired_timers(self, deadline):
        """Return ``{worker_id: [timer, timer]}`` for workers 1, 2 and 3.

        ``deadline`` is ignored.
        """
        expired = {}
        for i in (1, 2, 3):
            first = TimerRequest(i, f"test_{i}_0", 0)
            second = TimerRequest(i, f"test_{i}_1", 0)
            expired[i] = [first, second]
        return expired

    def _reap_worker(self, worker_id):
        """Simulate reaping: raise for 1, ``True`` for 2, ``False`` for 3."""
        if worker_id == 1:
            raise RuntimeError("test error")
        if worker_id == 2:
            return True
        if worker_id == 3:
            return False
        # Unknown ids fall through without an explicit verdict.
        return None
class TimerApiTest(unittest.TestCase):
    @mock.patch.object(MockTimerServer, "register_timers")
    @mock.patch.object(MockTimerServer, "clear_timers")
    def test_run_watchdog(self, mock_clear_timers, mock_register_timers):
        """
        Runs one watchdog pass against the mock server, whose
        ``_reap_worker()`` raises for worker 1, returns ``True`` for worker 2
        and returns ``False`` for worker 3, and checks that the queue is
        drained once, the drained requests are registered, and
        ``clear_timers()`` receives ``{1, 2}``.
        """
        interval = 1
        # Wrap the queue so calls are recorded while real values still flow through.
        queue_spy = mock.Mock(wraps=MockRequestQueue())
        server = MockTimerServer(queue_spy, interval)

        server._run_watchdog()

        # The watchdog must have asked for the queue size exactly once.
        queue_spy.size.assert_called_once()

        # The batch it fetched should be sized by the queue and bounded by the interval.
        expected_batch = queue_spy.size()
        queue_spy.get.assert_called_with(expected_batch, interval)

        # Whatever the queue handed back must have been registered.
        expected_requests = queue_spy.get(2, 1)
        mock_register_timers.assert_called_with(expected_requests)

        mock_clear_timers.assert_called_with({1, 2})
|