File: test_concurrency.py

package info (click to toggle)
celery 5.5.3-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 8,008 kB
  • sloc: python: 64,346; sh: 795; makefile: 378
file content (188 lines) | stat: -rw-r--r-- 5,687 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import importlib
import os
import sys
from itertools import count
from unittest.mock import Mock, patch

import pytest

from celery import concurrency
from celery.concurrency.base import BasePool, apply_target
from celery.exceptions import WorkerShutdown, WorkerTerminate


class test_BasePool:
    """Tests for ``apply_target`` and the default ``BasePool`` interface."""

    def test_apply_target(self):
        """Accept callback, target and result callback run in that order."""
        calls = {}
        sequence = count(0)

        def recorder(key, result=None):
            # Build a callable that stores its call position and arguments
            # under ``key`` and hands back ``result``.
            def _record(*received):
                calls[key] = (next(sequence), received)
                return result

            return _record

        apply_target(
            recorder('target', 42),
            args=(8, 16),
            callback=recorder('callback'),
            accept_callback=recorder('accept_callback'),
        )

        accept_position, accept_args = calls['accept_callback']
        assert accept_position == 0
        # First accept argument is this process' pid; the second only has
        # to be truthy.
        assert accept_args[0] == os.getpid()
        assert accept_args[1]
        assert calls['target'] == (1, (8, 16))
        assert calls['callback'] == (2, (42,))

        # Same call again, this time without an accept callback.
        calls.clear()
        apply_target(
            recorder('target', 42),
            args=(8, 16),
            callback=recorder('callback'),
            accept_callback=None,
        )
        expected_calls = {
            'target': (3, (8, 16)),
            'callback': (4, (42,)),
        }
        assert calls == expected_calls

    def test_apply_target__propagate(self):
        """A ``KeyError`` listed in ``propagate`` reaches the caller."""
        failing = Mock(name='target', side_effect=KeyError())
        with pytest.raises(KeyError):
            apply_target(failing, propagate=(KeyError,))

    def test_apply_target__raises(self):
        """A ``KeyError`` from the target reaches the caller by default."""
        failing = Mock(name='target', side_effect=KeyError())
        with pytest.raises(KeyError):
            apply_target(failing)

    def test_apply_target__raises_WorkerShutdown(self):
        """``WorkerShutdown`` raised by the target reaches the caller."""
        failing = Mock(name='target', side_effect=WorkerShutdown())
        with pytest.raises(WorkerShutdown):
            apply_target(failing)

    def test_apply_target__raises_WorkerTerminate(self):
        """``WorkerTerminate`` raised by the target reaches the caller."""
        failing = Mock(name='target', side_effect=WorkerTerminate())
        with pytest.raises(WorkerTerminate):
            apply_target(failing)

    def test_apply_target__raises_BaseException(self):
        """A plain ``BaseException`` is not re-raised; callback still runs."""
        failing = Mock(name='target', side_effect=BaseException())
        on_result = Mock(name='callback')
        apply_target(failing, callback=on_result)
        on_result.assert_called()

    @patch('celery.concurrency.base.reraise')
    def test_apply_target__raises_BaseException_raises_else(self, reraise):
        """An error from the patched ``reraise`` escapes; callback is skipped."""
        reraise.side_effect = KeyError()
        failing = Mock(name='target', side_effect=BaseException())
        on_result = Mock(name='callback')
        with pytest.raises(KeyError):
            apply_target(failing, callback=on_result)
        on_result.assert_not_called()

    def test_does_not_debug(self):
        """``apply_async`` runs with ``_does_debug`` switched off."""
        pool = BasePool(10)
        pool._does_debug = False
        pool.apply_async(object)

    def test_num_processes(self):
        """``num_processes`` equals the limit given to the constructor."""
        pool = BasePool(7)
        assert pool.num_processes == 7

    def test_interface_on_start(self):
        """The default ``on_start`` hook can be called without arguments."""
        pool = BasePool(10)
        pool.on_start()

    def test_interface_on_stop(self):
        """The default ``on_stop`` hook can be called without arguments."""
        pool = BasePool(10)
        pool.on_stop()

    def test_interface_on_apply(self):
        """The default ``on_apply`` hook can be called without arguments."""
        pool = BasePool(10)
        pool.on_apply()

    def test_interface_info(self):
        """``info`` reports the implementation path and max concurrency."""
        expected_info = {
            'implementation': 'celery.concurrency.base:BasePool',
            'max-concurrency': 10,
        }
        assert BasePool(10).info == expected_info

    def test_interface_flush(self):
        """The default ``flush`` returns ``None``."""
        outcome = BasePool(10).flush()
        assert outcome is None

    def test_active(self):
        """``active`` is falsy at first and truthy once state is ``RUN``."""
        pool = BasePool(10)
        assert not pool.active
        pool._state = pool.RUN
        assert pool.active

    def test_restart(self):
        """``restart`` is not implemented on the base pool."""
        pool = BasePool(10)
        with pytest.raises(NotImplementedError):
            pool.restart()

    def test_interface_on_terminate(self):
        """The default ``on_terminate`` hook can be called without arguments."""
        pool = BasePool(10)
        pool.on_terminate()

    def test_interface_terminate_job(self):
        """``terminate_job`` is not implemented on the base pool."""
        pool = BasePool(10)
        with pytest.raises(NotImplementedError):
            pool.terminate_job(101)

    def test_interface_did_start_ok(self):
        """The default ``did_start_ok`` returns a truthy value."""
        pool = BasePool(10)
        assert pool.did_start_ok()

    def test_interface_register_with_event_loop(self):
        """The default ``register_with_event_loop`` returns ``None``."""
        outcome = BasePool(10).register_with_event_loop(Mock())
        assert outcome is None

    def test_interface_on_soft_timeout(self):
        """The default ``on_soft_timeout`` returns ``None``."""
        outcome = BasePool(10).on_soft_timeout(Mock())
        assert outcome is None

    def test_interface_on_hard_timeout(self):
        """The default ``on_hard_timeout`` returns ``None``."""
        outcome = BasePool(10).on_hard_timeout(Mock())
        assert outcome is None

    def test_interface_close(self):
        """``close`` sets the ``CLOSE`` state and calls ``on_close``."""
        pool = BasePool(10)
        close_hook = Mock()
        pool.on_close = close_hook
        pool.close()
        assert pool._state == pool.CLOSE
        close_hook.assert_called_with()

    def test_interface_no_close(self):
        """The default ``on_close`` returns ``None``."""
        outcome = BasePool(10).on_close()
        assert outcome is None


class test_get_available_pool_names:
    """Check ``get_available_pool_names`` with and without ``concurrent.futures``.

    Each test reloads ``celery.concurrency`` while ``sys.modules`` is
    patched.  The module is reloaded a second time once the patch has been
    lifted, so whatever state the patched reload produced does not leak
    into tests that run afterwards.
    """

    @staticmethod
    def _pool_names_with_futures(futures_module):
        """Return the pool names reported after a reload under a patched import.

        ``futures_module`` is installed as ``sys.modules['concurrent.futures']``
        for the duration of the reload.  Passing ``None`` makes
        ``import concurrent.futures`` fail; any other object is handed back
        by the import system as if it were the real module.
        """
        try:
            with patch.dict(sys.modules,
                            {'concurrent.futures': futures_module}):
                importlib.reload(concurrency)
                return concurrency.get_available_pool_names()
        finally:
            # ``patch.dict`` has already restored ``sys.modules`` here, so
            # this reload rebuilds the module against the real environment.
            importlib.reload(concurrency)

    def test_no_concurrent_futures__returns_no_threads_pool_name(self):
        """Without ``concurrent.futures`` the ``threads`` pool is not listed."""
        expected_pool_names = (
            'prefork',
            'eventlet',
            'gevent',
            'solo',
            'processes',
            'custom',
        )
        assert self._pool_names_with_futures(None) == expected_pool_names

    def test_concurrent_futures__returns_threads_pool_name(self):
        """With ``concurrent.futures`` importable, ``threads`` is listed."""
        expected_pool_names = (
            'prefork',
            'eventlet',
            'gevent',
            'solo',
            'processes',
            'threads',
            'custom',
        )
        assert self._pool_names_with_futures(Mock()) == expected_pool_names