File: test_pickle_exception.py

package info (click to toggle)
python-tblib 3.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 392 kB
  • sloc: python: 786; makefile: 5
file content (166 lines) | stat: -rw-r--r-- 5,123 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from traceback import format_exception

try:
    import copyreg
except ImportError:
    # Python 2
    import copy_reg as copyreg

import pickle
import sys

import pytest

import tblib.pickling_support

# True on Python 3.11+; the tests below use it to gate add_note()/__notes__
# checks and the exception-group test.
has_python311 = sys.version_info >= (3, 11)


@pytest.fixture
def clear_dispatch_table():
    """Run the test against an empty ``copyreg.dispatch_table``.

    The table's contents are saved and emptied before the test, then put back
    afterwards so that reducers registered during the test do not leak out.
    """
    saved_entries = dict(copyreg.dispatch_table)
    copyreg.dispatch_table.clear()
    yield
    # Drop whatever the test registered before restoring the saved entries.
    copyreg.dispatch_table.clear()
    copyreg.dispatch_table.update(saved_entries)


class CustomError(Exception):
    """Plain ``Exception`` subclass with no added behavior, used as a test exception type."""


def strip_locations(tb_text):
    """Return ``tb_text`` with the two fixed marker lines removed.

    Only the exact tilde/caret lines listed below (four-space indent, trailing
    newline) are dropped; any other text is returned unchanged.
    """
    marker_lines = ('    ~~^~~\n', '    ^^^^^^^^^^^^^^^^^\n')
    for marker in marker_lines:
        tb_text = tb_text.replace(marker, '')
    return tb_text


@pytest.mark.parametrize('protocol', [None, *range(1, pickle.HIGHEST_PROTOCOL + 1)])
@pytest.mark.parametrize('how', ['global', 'instance', 'class'])
def test_install(clear_dispatch_table, how, protocol):
    """Check that a chained exception keeps its state across a pickle round trip.

    ``how`` selects the way pickling support is installed: with no arguments,
    for the listed exception classes, or from the exception instance itself.
    ``protocol`` is the pickle protocol to use; a falsy value (``None``) skips
    the round trip, so the same assertions also run on the original exception.
    """
    if how == 'global':
        tblib.pickling_support.install()
    elif how == 'class':
        tblib.pickling_support.install(CustomError, ValueError, ZeroDivisionError)

    try:
        try:
            try:
                1 / 0  # noqa: B018
            finally:
                # The ValueError's __context__ will be the ZeroDivisionError
                raise ValueError('blah')
        except Exception as e:
            new_e = CustomError('foo')
            if has_python311:
                new_e.add_note('note 1')
                new_e.add_note('note 2')
            # ``from e`` makes the ValueError the __cause__ of new_e.
            raise new_e from e
    except Exception as e:
        exc = e
    else:
        raise AssertionError

    # Rendered before any pickling; compared with the final rendering below.
    expected_format_exception = strip_locations(''.join(format_exception(type(exc), exc, exc.__traceback__)))

    # Populate Exception.__dict__, which is used in some cases
    exc.x = 1
    exc.__cause__.x = 2
    exc.__cause__.__context__.x = 3

    if how == 'instance':
        tblib.pickling_support.install(exc)
    if protocol:
        exc = pickle.loads(pickle.dumps(exc, protocol=protocol))  # noqa: S301

    assert isinstance(exc, CustomError)
    assert exc.args == ('foo',)
    assert exc.x == 1
    assert exc.__traceback__ is not None

    assert isinstance(exc.__cause__, ValueError)
    assert exc.__cause__.__traceback__ is not None
    assert exc.__cause__.x == 2
    assert exc.__cause__.__cause__ is None

    assert isinstance(exc.__cause__.__context__, ZeroDivisionError)
    assert exc.__cause__.__context__.x == 3
    assert exc.__cause__.__context__.__cause__ is None
    assert exc.__cause__.__context__.__context__ is None

    if has_python311:
        assert exc.__notes__ == ['note 1', 'note 2']

    assert expected_format_exception == strip_locations(''.join(format_exception(type(exc), exc, exc.__traceback__)))


@tblib.pickling_support.install
class RegisteredError(Exception):
    """Exception type passed to ``tblib.pickling_support.install`` as a class decorator."""


def test_install_decorator():
    """A ``RegisteredError`` should keep args, attributes and traceback through pickle.

    ``RegisteredError`` is decorated with ``tblib.pickling_support.install``;
    no other install call is made here.
    """
    with pytest.raises(RegisteredError) as caught:
        raise RegisteredError('foo')
    original = caught.value
    original.x = 1

    restored = pickle.loads(pickle.dumps(original))  # noqa: S301

    assert isinstance(restored, RegisteredError)
    assert restored.args == ('foo',)
    assert restored.x == 1
    assert restored.__traceback__ is not None


@pytest.mark.skipif(not has_python311, reason='no BaseExceptionGroup before Python 3.11')
def test_install_instance_recursively(clear_dispatch_table):
    """Installing from an exception group should register every nested exception type.

    The group holds a ``ValueError`` (whose ``__cause__`` is a
    ``ZeroDivisionError`` with an ``AttributeError`` as ``__context__``) and a
    ``CustomError``. After ``install(group)`` the exception classes present in
    ``copyreg.dispatch_table`` must be exactly those types plus ``ExceptionGroup``.
    """
    chained = ValueError('foo')
    chained.__cause__ = ZeroDivisionError('baz')
    chained.__cause__.__context__ = AttributeError('quux')
    group = BaseExceptionGroup('test', [chained, CustomError('bar')])  # noqa: F821

    tblib.pickling_support.install(group)

    expected = {ExceptionGroup, ValueError, CustomError, ZeroDivisionError, AttributeError}  # noqa: F821
    registered = {cls for cls in copyreg.dispatch_table if issubclass(cls, BaseException)}
    assert registered == expected


def test_install_typeerror():
    """``install`` should raise ``TypeError`` when handed a plain string."""
    invalid_target = 'foo'
    with pytest.raises(TypeError):
        tblib.pickling_support.install(invalid_target)


@pytest.mark.parametrize('protocol', [None, *range(1, pickle.HIGHEST_PROTOCOL + 1)])
@pytest.mark.parametrize('how', ['global', 'instance', 'class'])
def test_get_locals(clear_dispatch_table, how, protocol):
    """Check that the ``get_locals`` callback decides which frame locals survive pickling.

    ``how`` selects the way pickling support is installed (no arguments, listed
    classes, or the exception instance); ``protocol`` is handed to
    ``pickle.dumps`` as-is, including ``None``.
    """

    def get_locals(frame):
        # Keep only ``my_variable``, converted to int; every other local is dropped.
        if 'my_variable' not in frame.f_locals:
            return {}
        return {'my_variable': int(frame.f_locals['my_variable'])}

    def func(my_arg='2'):
        my_variable = '1'
        raise ValueError(my_variable)

    if how == 'global':
        tblib.pickling_support.install(get_locals=get_locals)
    elif how == 'class':
        tblib.pickling_support.install(CustomError, ValueError, ZeroDivisionError, get_locals=get_locals)

    try:
        func()
    except Exception as e:
        raised = e
    else:
        raise AssertionError

    # tb_next steps from this test's frame into func's frame.
    live_locals = raised.__traceback__.tb_next.tb_frame.f_locals
    assert 'my_variable' in live_locals
    assert live_locals['my_variable'] == '1'

    if how == 'instance':
        tblib.pickling_support.install(raised, get_locals=get_locals)

    restored = pickle.loads(pickle.dumps(raised, protocol=protocol))  # noqa: S301
    assert restored.__traceback__.tb_next.tb_frame.f_locals == {'my_variable': 1}