File: error_handler_test.py

package info (click to toggle)
python-certbot 4.0.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,688 kB
  • sloc: python: 21,764; makefile: 182; sh: 108
file content (155 lines) | stat: -rw-r--r-- 5,202 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Tests for certbot._internal.error_handler."""
import contextlib
import signal
import sys
from typing import Callable
from typing import Dict
from typing import Union
import unittest
from unittest import mock

import pytest

from certbot.compat import os


def get_signals(signums):
    """Map every signal number in *signums* to its currently installed handler."""
    handlers = {}
    for signum in signums:
        handlers[signum] = signal.getsignal(signum)
    return handlers


def set_signals(sig_handler_dict):
    """Install, for each signal number key of *sig_handler_dict*, its handler value."""
    for signum in sig_handler_dict:
        signal.signal(signum, sig_handler_dict[signum])


@contextlib.contextmanager
def signal_receiver(signums):
    """Context manager to catch signals.

    Installs, for every signal number in *signums*, a handler that appends the
    received signal number to a list, and yields that list. The previously
    installed handlers are restored on exit, including when the body raises,
    so a failing test cannot leak these handlers into later tests.

    :param signums: iterable of signal numbers to intercept
    :returns: (yields) list of received signal numbers, in delivery order
    """
    # Materialize once: signums is iterated twice below, which would silently
    # install no handlers if a one-shot iterator were passed.
    signums = list(signums)
    signals = []
    prev_handlers: Dict[int, Union[int, None, Callable]] = get_signals(signums)
    set_signals({s: lambda signum, _frame: signals.append(signum) for s in signums})
    try:
        yield signals
    finally:
        set_signals(prev_handlers)


def send_signal(signum):
    """Deliver signal *signum* to the current process."""
    own_pid = os.getpid()
    os.kill(own_pid, signum)


class ErrorHandlerTest(unittest.TestCase):
    """Tests for certbot._internal.error_handler.ErrorHandler."""

    def setUp(self):
        from certbot._internal import error_handler

        self.init_func = mock.MagicMock()
        # A tuple rather than a set: these values are unpacked positionally,
        # and a set gives no ordering guarantee once it has several members.
        self.init_args = (42,)
        self.init_kwargs = {'foo': 'bar'}
        self.handler = error_handler.ErrorHandler(self.init_func,
                                                  *self.init_args,
                                                  **self.init_kwargs)

        # pylint: disable=protected-access
        self.signals = error_handler._SIGNALS

    def test_context_manager(self):
        # An exception in the body triggers the registered function exactly
        # once and is still propagated to the caller.
        with pytest.raises(ValueError):
            with self.handler:
                raise ValueError

        self.init_func.assert_called_once_with(*self.init_args,
                                               **self.init_kwargs)

    def test_context_manager_with_signal(self):
        if not self.signals:
            self.skipTest(reason='Signals cannot be handled on Windows.')
        init_signals = get_signals(self.signals)
        with signal_receiver(self.signals) as signals_received:
            with self.handler:
                should_be_42 = 42
                send_signal(self.signals[0])
                # Must not run: the signal is expected to abort the body.
                should_be_42 *= 10

        # check execution stopped when the signal was sent
        assert 42 == should_be_42
        # assert the outer receiver still saw the signal
        assert [self.signals[0]] == signals_received
        # assert the error handling function was just called once
        self.init_func.assert_called_once_with(*self.init_args,
                                               **self.init_kwargs)
        # every handler is back to its original value once both context
        # managers have exited
        for signum in self.signals:
            assert init_signals[signum] == signal.getsignal(signum)

    def test_bad_recovery(self):
        # A registered function that raises must not stop init_func from
        # being called, and a ValueError still reaches the caller.
        bad_func = mock.MagicMock(side_effect=[ValueError])
        self.handler.register(bad_func)
        with pytest.raises(ValueError):
            with self.handler:
                raise ValueError
        self.init_func.assert_called_once_with(*self.init_args,
                                               **self.init_kwargs)
        bad_func.assert_called_once_with()

    def test_bad_recovery_with_signal(self):
        if not self.signals:
            self.skipTest(reason='Signals cannot be handled on Windows.')
        sig1 = self.signals[0]
        sig2 = self.signals[-1]
        # The recovery function itself sends a second signal; both signals
        # must reach the outer receiver, in the order they were sent.
        bad_func = mock.MagicMock(side_effect=lambda: send_signal(sig1))
        self.handler.register(bad_func)
        with signal_receiver(self.signals) as signals_received:
            with self.handler:
                send_signal(sig2)
        assert [sig2, sig1] == signals_received
        self.init_func.assert_called_once_with(*self.init_args,
                                               **self.init_kwargs)
        bad_func.assert_called_once_with()

    def test_sysexit_ignored(self):
        # SystemExit propagates without triggering the registered functions.
        with pytest.raises(SystemExit):
            with self.handler:
                sys.exit(0)
        self.init_func.assert_not_called()

    def test_regular_exit(self):
        # ErrorHandler does nothing when the body completes normally.
        func = mock.MagicMock()
        self.handler.register(func)
        with self.handler:
            pass
        self.init_func.assert_not_called()
        func.assert_not_called()


class ExitHandlerTest(ErrorHandlerTest):
    """Tests for certbot._internal.error_handler.ExitHandler.

    Every inherited ErrorHandlerTest case is re-run with ``self.handler``
    replaced by an ExitHandler; only the regular-exit case is overridden.
    """

    def setUp(self):
        super().setUp()
        from certbot._internal import error_handler
        self.handler = error_handler.ExitHandler(
            self.init_func, *self.init_args, **self.init_kwargs)

    def test_regular_exit(self):
        # Unlike ErrorHandler, ExitHandler calls its functions even when the
        # body completes without an exception.
        registered = mock.MagicMock()
        self.handler.register(registered)
        with self.handler:
            pass
        self.init_func.assert_called_once_with(
            *self.init_args, **self.init_kwargs)
        registered.assert_called_once_with()


if __name__ == "__main__":
    # Run this module's tests directly; any extra CLI arguments go to pytest.
    pytest_args = sys.argv[1:] + [__file__]  # pragma: no cover
    sys.exit(pytest.main(pytest_args))  # pragma: no cover