File: test_backends.py

package info (click to toggle)
celery 5.6.2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 8,376 kB
  • sloc: python: 67,264; sh: 795; makefile: 378
file content (136 lines) | stat: -rw-r--r-- 4,510 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import threading
from contextlib import contextmanager
from unittest.mock import patch

import pytest

import celery.contrib.testing.worker as contrib_embed_worker
from celery.app import backends
from celery.backends.cache import CacheBackend
from celery.exceptions import ImproperlyConfigured
from celery.utils.nodenames import anon_nodename


class CachedBackendWithTreadTrucking(CacheBackend):
    """Cache backend that records which thread touches each instance.

    Every attribute lookup on an instance (except dunder names other than
    ``__init__`` and the tracking machinery itself) is appended to the
    class-level ``test_call_stats`` registry, keyed by a per-instance
    sequence number. Tests use it to assert that a single backend instance
    is never shared between threads.
    """

    # Next sequence number handed out to a newly tracked instance.
    test_instance_count = 0
    # Deliberately class-level registry shared by all instances:
    # instance number -> list of {'thread_id': ..., 'method_name': ...}.
    test_call_stats = {}
    # Serializes instance registration. Instances may be first touched
    # concurrently from different threads, and reading + incrementing
    # ``test_instance_count`` is not atomic. Without this lock two instances
    # could receive the same number and their call logs would be merged.
    _registry_lock = threading.Lock()

    def _track_attribute_access(self, method_name):
        """Record that *method_name* was looked up on this instance."""
        # Class-level lookups go through type(self), not through the
        # instance ``__getattribute__``, so they are not themselves tracked.
        cls = type(self)

        instance_no = getattr(self, '_instance_no', None)
        if instance_no is None:
            with cls._registry_lock:
                # Re-check under the lock: another thread may have registered
                # this very instance after the unlocked check above.
                instance_no = getattr(self, '_instance_no', None)
                if instance_no is None:
                    instance_no = cls.test_instance_count
                    cls.test_instance_count += 1
                    # Create the registry slot before publishing the number on
                    # the instance, so any thread that sees ``_instance_no``
                    # also finds its list.
                    cls.test_call_stats[instance_no] = []
                    self._instance_no = instance_no

        cls.test_call_stats[instance_no].append({
            'thread_id': threading.get_ident(),
            'method_name': method_name
        })

    def __getattribute__(self, name):
        """Log the lookup of *name*, then perform the normal lookup."""
        # The tracking machinery must bypass tracking; otherwise the lookups
        # made inside _track_attribute_access would recurse forever.
        if name in ('_instance_no', '_track_attribute_access'):
            return super().__getattribute__(name)

        # Dunder lookups are passed through untracked, except ``__init__``.
        if name.startswith('__') and name != '__init__':
            return super().__getattribute__(name)

        self._track_attribute_access(name)
        return super().__getattribute__(name)


@contextmanager
def embed_worker(app,
                 concurrency=1,
                 pool='threads',
                 shutdown_timeout=10.0,
                 **kwargs):
    """
    Helper embedded worker for testing.

    It's based on :func:`celery.contrib.testing.worker.start_worker`,
    but doesn't modify logging settings and additionally shuts down the
    worker pool.

    :param app: Celery application to run the worker for. It is finalized
        and set as the current app.
    :param concurrency: Pool concurrency passed to ``TestWorkController``.
    :param pool: Pool implementation name. The default is ``'threads'``,
        celery's thread-pool alias. The previous default, ``'threading'``,
        does not name a registered pool.
    :param shutdown_timeout: Seconds to wait for the worker thread to exit
        once the ``with`` block is left (previously hard-coded to 10.0).
    :param kwargs: Extra arguments forwarded to ``TestWorkController``.
        ``without_heartbeat`` defaults to ``True``.
    :raises RuntimeError: If the ``with`` block exited normally but the worker
        thread is still alive after *shutdown_timeout* seconds.
    """
    # prepare application for worker
    app.finalize()
    app.set_current()

    worker = contrib_embed_worker.TestWorkController(
        app=app,
        concurrency=concurrency,
        hostname=anon_nodename(),
        pool=pool,
        # not allowed to override TestWorkController.on_consumer_ready
        ready_callback=None,
        without_heartbeat=kwargs.pop("without_heartbeat", True),
        without_mingle=True,
        without_gossip=True,
        **kwargs
    )

    t = threading.Thread(target=worker.start, daemon=True)
    t.start()
    worker.ensure_started()

    try:
        yield worker
    finally:
        # Stop the worker even when the body of the ``with`` block raised,
        # so a failing test does not leave a running worker behind.
        worker.stop()
        t.join(shutdown_timeout)

    # Only reached on a normal exit. If the body raised, that exception
    # propagates instead of being masked by this one.
    if t.is_alive():
        raise RuntimeError(
            "Worker thread failed to exit within the allocated timeout. "
            "Consider raising `shutdown_timeout` if your tasks take longer "
            "to execute."
        )


class test_backends:
    """Tests for result-backend lookup and per-thread backend usage."""

    @pytest.mark.parametrize('url,expect_cls', [
        ('cache+memory://', CacheBackend),
    ])
    def test_get_backend_aliases(self, url, expect_cls, app):
        backend, url = backends.by_url(url, app.loader)
        assert isinstance(backend(app=app, url=url), expect_cls)

    def test_unknown_backend(self, app):
        with pytest.raises(ImportError):
            backends.by_name('fasodaopjeqijwqe', app.loader)

    def test_backend_by_url(self, app, url='redis://localhost/1'):
        from celery.backends.redis import RedisBackend
        backend, url_ = backends.by_url(url, app.loader)
        assert backend is RedisBackend
        assert url_ == url

    def test_sym_raises_ValuError(self, app):
        with patch('celery.app.backends.symbol_by_name') as sbn:
            sbn.side_effect = ValueError()
            with pytest.raises(ImproperlyConfigured):
                backends.by_name('xxx.xxx:foo', app.loader)

    def test_backend_can_not_be_module(self, app):
        with pytest.raises(ImproperlyConfigured):
            backends.by_name(pytest, app.loader)

    @pytest.mark.celery(
        result_backend=f'{CachedBackendWithTreadTrucking.__module__}.'
        f'{CachedBackendWithTreadTrucking.__qualname__}'
        f'+memory://')
    def test_backend_thread_safety(self):
        @self.app.task
        def dummy_add_task(x, y):
            return x + y

        with embed_worker(app=self.app, pool='threads'):
            result = dummy_add_task.delay(6, 9)
            assert result.get(timeout=10) == 15

        call_stats = CachedBackendWithTreadTrucking.test_call_stats
        # Guard against a vacuous pass: if the tracking backend was never
        # used, the per-instance check below would trivially succeed.
        assert call_stats, "The tracking backend was never used"
        # Each backend instance must only ever be touched by a single thread.
        for backend_call_stats in call_stats.values():
            thread_ids = {stat['thread_id'] for stat in backend_call_stats}
            assert len(thread_ids) <= 1, \
                "The same celery backend instance is used by multiple threads"