1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
|
from traceback import format_exception
try:
import copyreg
except ImportError:
# Python 2
import copy_reg as copyreg
import pickle
import sys
import pytest
import tblib.pickling_support
# Gate for features that only exist on Python 3.11+ (add_note/__notes__, ExceptionGroup).
has_python311 = sys.version_info[:2] >= (3, 11)
@pytest.fixture
def clear_dispatch_table():
    """Run the requesting test against an empty ``copyreg.dispatch_table``.

    The global dispatch table is snapshotted and emptied before the test. After
    the test it is emptied again and the snapshot restored, so registrations
    made during the test do not leak into other tests.
    """
    saved_entries = dict(copyreg.dispatch_table)
    copyreg.dispatch_table.clear()
    yield
    # Teardown: drop whatever the test registered, then put the originals back.
    copyreg.dispatch_table.clear()
    copyreg.dispatch_table.update(saved_entries)
class CustomError(Exception):
    """Test-local exception type used as the outermost link of pickled exception chains."""
def strip_locations(tb_text):
    """Remove column-marker lines from formatted traceback text.

    Python 3.11+ may print, under a source line, a line made only of ``~`` and
    ``^`` characters that points at the failing sub-expression. Its width and
    indentation vary with the expression. A rebuilt traceback may not reproduce
    it, so such lines are dropped before traceback text is compared. (Caret
    lines under SyntaxError source are removed too. That is harmless when both
    sides of a comparison are stripped.)

    Args:
        tb_text: formatted traceback text, e.g.
            ``''.join(traceback.format_exception(...))``.

    Returns:
        ``tb_text`` with every marker line, including its line ending, removed.
        All other lines are returned unchanged.
    """
    marker_chars = {'~', '^'}
    kept = []
    for line in tb_text.splitlines(keepends=True):
        content = line.strip()
        # A non-empty line of only '~'/'^' can never be real source code, so
        # it must be a location marker, whatever its width or indentation.
        if content and set(content) <= marker_chars:
            continue
        kept.append(line)
    return ''.join(kept)
@pytest.mark.parametrize('protocol', [None, *list(range(1, pickle.HIGHEST_PROTOCOL + 1))])
@pytest.mark.parametrize('how', ['global', 'instance', 'class'])
def test_install(clear_dispatch_table, how, protocol):
    """Pickle a chained exception after registering it with ``install()``.

    Parameters (injected by pytest):
        clear_dispatch_table: fixture that empties ``copyreg.dispatch_table``
            for the duration of the test.
        how: which form of ``tblib.pickling_support.install`` to exercise.
            ``'global'`` calls it with no arguments. ``'class'`` passes the
            exception classes explicitly. ``'instance'`` passes the caught
            exception object itself.
        protocol: pickle protocol for the round-trip. ``None`` skips the
            round-trip entirely, so the same assertions run against the original
            exception as a baseline.
    """
    # 'instance' registration needs the exception object, so it happens further down.
    if how == 'global':
        tblib.pickling_support.install()
    elif how == 'class':
        tblib.pickling_support.install(CustomError, ValueError, ZeroDivisionError)
    # Build the chain CustomError --__cause__--> ValueError --__context__--> ZeroDivisionError.
    try:
        try:
            try:
                1 / 0  # noqa: B018
            finally:
                # The ValueError's __context__ will be the ZeroDivisionError
                raise ValueError('blah')
        except Exception as e:
            # Equivalent to ``raise CustomError('foo') from e``. The instance is
            # created first so notes can be attached to it before raising.
            new_e = CustomError('foo')
            # Redundant with ``from e`` below (which also sets __cause__), but harmless.
            new_e.__cause__ = e
            if has_python311:
                # add_note() is only available on Python 3.11+.
                new_e.add_note('note 1')
                new_e.add_note('note 2')
            raise new_e from e
    except Exception as e:
        exc = e
    else:
        # The block above always raises; getting here means the setup is broken.
        raise AssertionError
    # Formatted traceback of the original (not yet pickled) exception, markers removed.
    expected_format_exception = strip_locations(''.join(format_exception(type(exc), exc, exc.__traceback__)))
    # Attach ad-hoc attributes (stored in each exception's __dict__) so the
    # assertions below can check that instance state survives pickling.
    exc.x = 1
    exc.__cause__.x = 2
    exc.__cause__.__context__.x = 3
    if how == 'instance':
        # Register starting from the exception object rather than its classes.
        tblib.pickling_support.install(exc)
    # protocol=None is falsy: no round-trip, the assertions check the original object.
    if protocol:
        exc = pickle.loads(pickle.dumps(exc, protocol=protocol))  # noqa: S301
    # Outermost exception: type, args, extra attribute and traceback preserved.
    assert isinstance(exc, CustomError)
    assert exc.args == ('foo',)
    assert exc.x == 1
    assert exc.__traceback__ is not None
    # Explicit cause (ValueError) preserved, with no cause of its own.
    assert isinstance(exc.__cause__, ValueError)
    assert exc.__cause__.__traceback__ is not None
    assert exc.__cause__.x == 2
    assert exc.__cause__.__cause__ is None
    # Implicit context (ZeroDivisionError) preserved and terminates the chain.
    assert isinstance(exc.__cause__.__context__, ZeroDivisionError)
    assert exc.__cause__.__context__.x == 3
    assert exc.__cause__.__context__.__cause__ is None
    assert exc.__cause__.__context__.__context__ is None
    if has_python311:
        assert exc.__notes__ == ['note 1', 'note 2']
    # Marker lines are stripped on both sides before the texts are compared.
    assert expected_format_exception == strip_locations(''.join(format_exception(type(exc), exc, exc.__traceback__)))
@tblib.pickling_support.install
class RegisteredError(Exception):
    """Exception class registered at import time by using ``install`` as a class decorator."""
def test_install_decorator():
    """A class decorated with ``install`` pickles with its args, attributes and traceback."""
    # No clear_dispatch_table fixture here: RegisteredError was registered by
    # its decorator at import time, and clearing the table would undo that.
    with pytest.raises(RegisteredError) as excinfo:
        raise RegisteredError('foo')
    original = excinfo.value
    original.x = 1

    restored = pickle.loads(pickle.dumps(original))  # noqa: S301

    assert isinstance(restored, RegisteredError)
    assert restored.args == ('foo',)
    assert restored.x == 1
    assert restored.__traceback__ is not None
@pytest.mark.skipif(not has_python311, reason='no BaseExceptionGroup before Python 3.11')
def test_install_instance_recursively(clear_dispatch_table):
    """``install(instance)`` registers every exception class reachable from the instance.

    Reachability covers group members as well as __cause__ and __context__
    links. Because every member is an ``Exception``, the ``BaseExceptionGroup``
    constructor yields an ``ExceptionGroup``, which is the class expected to
    be registered.
    """
    inner_cause = ZeroDivisionError('baz')
    inner_cause.__context__ = AttributeError('quux')
    first_member = ValueError('foo')
    first_member.__cause__ = inner_cause
    group = BaseExceptionGroup('test', [first_member, CustomError('bar')])  # noqa: F821

    tblib.pickling_support.install(group)

    registered = set(filter(lambda cls: issubclass(cls, BaseException), copyreg.dispatch_table))
    expected = {ExceptionGroup, ValueError, CustomError, ZeroDivisionError, AttributeError}  # noqa: F821
    assert registered == expected
def test_install_typeerror():
    """``install`` raises TypeError when handed a plain string."""
    bogus_target = 'foo'
    with pytest.raises(TypeError):
        tblib.pickling_support.install(bogus_target)
@pytest.mark.parametrize('protocol', [None, *list(range(1, pickle.HIGHEST_PROTOCOL + 1))])
@pytest.mark.parametrize('how', ['global', 'instance', 'class'])
def test_get_locals(clear_dispatch_table, how, protocol):
    """A ``get_locals`` callback decides which frame locals survive pickling.

    The callback keeps only ``my_variable``, converted to ``int``. The raising
    frame really holds ``my_variable == '1'``, but after the round-trip its
    locals must be exactly ``{'my_variable': 1}``.
    """

    def get_locals(frame):
        frame_locals = frame.f_locals
        if 'my_variable' not in frame_locals:
            return {}
        return {'my_variable': int(frame_locals['my_variable'])}

    def func(my_arg='2'):
        my_variable = '1'
        raise ValueError(my_variable)

    if how == 'global':
        tblib.pickling_support.install(get_locals=get_locals)
    elif how == 'class':
        tblib.pickling_support.install(CustomError, ValueError, ZeroDivisionError, get_locals=get_locals)

    try:
        func()
    except Exception as caught:
        exc = caught
    else:
        raise AssertionError

    # tb_next is func's frame (the first entry is this test's own frame).
    raising_frame_locals = exc.__traceback__.tb_next.tb_frame.f_locals
    assert 'my_variable' in raising_frame_locals
    assert raising_frame_locals['my_variable'] == '1'

    if how == 'instance':
        tblib.pickling_support.install(exc, get_locals=get_locals)
    payload = pickle.dumps(exc, protocol=protocol)
    exc = pickle.loads(payload)  # noqa: S301
    assert exc.__traceback__.tb_next.tb_frame.f_locals == {'my_variable': 1}
|