File: test_ccallback.py

package info (click to toggle)
python-scipy 1.1.0-7
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 93,828 kB
  • sloc: python: 156,854; ansic: 82,925; fortran: 80,777; cpp: 7,505; makefile: 427; sh: 294
file content (199 lines) | stat: -rw-r--r-- 6,061 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from __future__ import division, print_function, absolute_import

from numpy.testing import assert_equal, assert_
from pytest import raises as assert_raises

import time
import pytest
import ctypes
import threading
from scipy._lib import _ccallback_c as _test_ccallback_cython
from scipy._lib import _test_ccallback
from scipy._lib._ccallback import LowLevelCallable

try:
    import cffi
    HAVE_CFFI = True
except ImportError:
    HAVE_CFFI = False


# Input value for which `callback_python` raises instead of returning.
ERROR_VALUE = 2.0


def callback_python(a, user_data=None):
    """Pure-Python callback used as the reference implementation in the tests.

    Parameters
    ----------
    a : number
        Value handed to the callback.
    user_data : number, optional
        Amount added to `a`. When omitted (None), 1 is added instead.

    Returns
    -------
    number
        ``a + 1`` without `user_data`, ``a + user_data`` otherwise.

    Raises
    ------
    ValueError
        If `a` equals ``ERROR_VALUE``.
    """
    if a == ERROR_VALUE:
        raise ValueError("bad value")

    increment = 1 if user_data is None else user_data
    return a + increment

def _get_cffi_func(base, signature):
    """Return a cffi object of C type `signature` pointing at the ctypes object `base`.

    Calls ``pytest.skip`` when cffi could not be imported.
    """
    if not HAVE_CFFI:
        pytest.skip("cffi not installed")

    # Read the raw address out of `base` by viewing it as a void pointer ...
    address = ctypes.cast(base, ctypes.c_void_p).value

    # ... and cast that integer address to the requested cffi type.
    return cffi.FFI().cast(signature, address)


def _get_ctypes_data():
    """Return a ctypes void pointer to a new ``c_double`` holding 2.0."""
    boxed = ctypes.c_double(2.0)
    double_ptr = ctypes.pointer(boxed)
    return ctypes.cast(double_ptr, ctypes.c_voidp)


def _get_cffi_data():
    """Return a cffi ``double *`` initialised to 2.0.

    Calls ``pytest.skip`` when cffi could not be imported.
    """
    if HAVE_CFFI:
        return cffi.FFI().new('double *', 2.0)
    pytest.skip("cffi not installed")


# Entry points from the `_test_ccallback` and `_ccallback_c` helper modules,
# keyed by a short name.  The tests below invoke each one as
# ``caller(callback, value)``.
CALLERS = {
    'simple': _test_ccallback.test_call_simple,
    'nodata': _test_ccallback.test_call_nodata,
    'nonlocal': _test_ccallback.test_call_nonlocal,
    'cython': _test_ccallback_cython.test_call_cython,
}

# These functions have signatures known to the callers.
#
# Each value is a zero-argument factory rather than the callback itself, so
# the callback is only built when a test calls the factory.  This matters for
# the cffi entries: `_get_cffi_func` calls ``pytest.skip`` when cffi is not
# installed, and that must happen inside a running test, not at import time.
FUNCS = {
    'python': lambda: callback_python,
    'capsule': lambda: _test_ccallback.test_get_plus1_capsule(),
    'cython': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1_cython"),
    'ctypes': lambda: _test_ccallback_cython.plus1_ctypes,
    'cffi': lambda: _get_cffi_func(_test_ccallback_cython.plus1_ctypes,
                                   'double (*)(double, int *, void *)'),
    'capsule_b': lambda: _test_ccallback.test_get_plus1b_capsule(),
    'cython_b': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1b_cython"),
    'ctypes_b': lambda: _test_ccallback_cython.plus1b_ctypes,
    'cffi_b': lambda: _get_cffi_func(_test_ccallback_cython.plus1b_ctypes,
                                     'double (*)(double, double, int *, void *)'),
}

# These functions have signatures the callers don't know, so
# `test_bad_callbacks` expects every caller to reject them with ValueError.
# As in FUNCS, each value is a zero-argument factory evaluated inside the test.
BAD_FUNCS = {
    'capsule_bc': lambda: _test_ccallback.test_get_plus1bc_capsule(),
    'cython_bc': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1bc_cython"),
    'ctypes_bc': lambda: _test_ccallback_cython.plus1bc_ctypes,
    'cffi_bc': lambda: _get_cffi_func(_test_ccallback_cython.plus1bc_ctypes,
                                      'double (*)(double, double, double, int *, void *)'),
}

# Zero-argument factories for the ``user_data`` object that the tests pass as
# the second argument to `LowLevelCallable`, keyed by the kind of object built.
# The ctypes and cffi factories both wrap a double set to 2.0.
USER_DATAS = {
    'ctypes': _get_ctypes_data,
    'cffi': _get_cffi_data,
    'capsule': _test_ccallback.test_get_data_capsule,
}


def test_callbacks():
    """Exercise every caller with every known-signature callback and user_data kind.

    For each (caller, func, user_data) combination this checks that a plain
    call returns ``value + 1``, that ``ERROR_VALUE`` makes the caller raise
    ValueError, and that supplying user_data (2.0) returns ``value + 2``.

    Combinations that need cffi call ``pytest.skip`` when cffi is missing.
    That skip is caught per combination so the remaining combinations still
    run; the test is reported as skipped at the end if any combination could
    not be run.
    """
    def check(caller, func, user_data):
        caller = CALLERS[caller]
        func = FUNCS[func]()
        user_data = USER_DATAS[user_data]()

        if func is callback_python:
            # A plain Python callable receives user_data as a normal argument.
            func2 = lambda x: func(x, 2.0)
        else:
            func2 = LowLevelCallable(func, user_data)
            func = LowLevelCallable(func)

        # Test basic call
        assert_equal(caller(func, 1.0), 2.0)

        # Test 'bad' value resulting in an error
        assert_raises(ValueError, caller, func, ERROR_VALUE)

        # Test passing in user_data
        assert_equal(caller(func2, 1.0), 3.0)

    skip_reasons = []
    for caller in sorted(CALLERS):
        for func in sorted(FUNCS):
            for user_data in sorted(USER_DATAS):
                try:
                    check(caller, func, user_data)
                except pytest.skip.Exception as exc:
                    # Do not let one unavailable backend abort the whole
                    # matrix; remember why and keep going.
                    skip_reasons.append(str(exc))

    if skip_reasons:
        pytest.skip("%d combination(s) not run: %s"
                    % (len(skip_reasons), skip_reasons[0]))


def test_bad_callbacks():
    """Check that every caller rejects callbacks whose signature it does not know.

    For each (caller, func, user_data) combination the call must raise
    ValueError both with and without user_data, and the error message must
    contain the callback's own signature as well as one signature visible in
    this check, ``double (double, double, int *, void *)``.

    Combinations that need cffi call ``pytest.skip`` when cffi is missing.
    That skip is caught per combination so the remaining combinations still
    run; the test is reported as skipped at the end if any combination could
    not be run.
    """
    def check(caller, func, user_data):
        caller = CALLERS[caller]
        user_data = USER_DATAS[user_data]()
        func = BAD_FUNCS[func]()

        # BAD_FUNCS contains no plain Python callable, so every entry is
        # wrapped as a LowLevelCallable (once, and reused below).
        llfunc_with_data = LowLevelCallable(func, user_data)
        llfunc = LowLevelCallable(func)

        # Test that basic call fails.  `assert_raises` is `pytest.raises`;
        # the context-manager form keeps the exception so its message can be
        # checked without a second call that might silently not raise.
        with assert_raises(ValueError) as excinfo:
            caller(llfunc, 1.0)

        # Test that passing in user_data also fails
        assert_raises(ValueError, caller, llfunc_with_data, 1.0)

        # Test error message
        msg = str(excinfo.value)
        assert_(llfunc.signature in msg, msg)
        assert_('double (double, double, int *, void *)' in msg, msg)

    skip_reasons = []
    for caller in sorted(CALLERS):
        for func in sorted(BAD_FUNCS):
            for user_data in sorted(USER_DATAS):
                try:
                    check(caller, func, user_data)
                except pytest.skip.Exception as exc:
                    # Do not let one unavailable backend abort the whole
                    # matrix; remember why and keep going.
                    skip_reasons.append(str(exc))

    if skip_reasons:
        pytest.skip("%d combination(s) not run: %s"
                    % (len(skip_reasons), skip_reasons[0]))


def test_signature_override():
    """Check that an explicit ``signature=`` given to LowLevelCallable is honoured.

    The signature string must be reported back unchanged, an unrecognised one
    must make the simple caller raise ValueError, and a matching one must let
    the call succeed.
    """
    capsule = _test_ccallback.test_get_plus1_capsule()
    call_simple = _test_ccallback.test_call_simple

    # Unrecognised signature: stored as given, rejected when called.
    wrapped = LowLevelCallable(capsule, signature="bad signature")
    assert_equal(wrapped.signature, "bad signature")
    assert_raises(ValueError, call_simple, wrapped, 3)

    # Matching signature: stored as given, call returns value + 1.
    wrapped = LowLevelCallable(capsule,
                               signature="double (double, int *, void *)")
    assert_equal(wrapped.signature, "double (double, int *, void *)")
    assert_equal(call_simple(wrapped, 3), 4)


def test_threadsafety():
    """Run nested callbacks through each caller from 20 threads at once.

    Every thread starts a chain in which the Python callback re-enters the
    caller until its argument reaches zero, doubling the result on each level
    on the way back.  With a starting value of 10 every thread must therefore
    obtain ``2.0**10``.
    """
    def callback(a, caller):
        if a <= 0:
            return 1
        # Re-enter the caller one level deeper, then double what came back.
        inner = caller(lambda x: callback(x, caller), a - 1)
        return 2*inner

    def check(caller):
        caller = CALLERS[caller]
        depth = 10
        outputs = []

        def worker():
            # Short pause before calling in, as in every worker thread.
            time.sleep(0.01)
            outputs.append(caller(lambda x: callback(x, caller), depth))

        pool = [threading.Thread(target=worker) for _ in range(20)]
        for t in pool:
            t.start()
        for t in pool:
            t.join()

        expected = [2.0**depth]*len(pool)
        assert_equal(outputs, expected)

    for name in CALLERS:
        check(name)