File: conftest.py

package info (click to toggle)
celery 5.5.3-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 8,008 kB
  • sloc: python: 64,346; sh: 795; makefile: 378
file content (787 lines) | stat: -rw-r--r-- 23,215 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
import builtins
import inspect
import io
import logging
import os
import platform
import sys
import threading
import types
import warnings
from contextlib import contextmanager
from functools import wraps
from importlib import import_module, reload
from unittest.mock import MagicMock, Mock, patch

import pytest
from kombu import Queue

from celery.backends.cache import CacheBackend, DummyClient
# we have to import the pytest plugin fixtures here,
# in case user did not do the `python setup.py develop` yet,
# that installs the pytest plugin into the setuptools registry.
from celery.contrib.pytest import celery_app, celery_enable_logging, celery_parameters, depends_on_current_app
from celery.contrib.testing.app import TestApp, Trap
from celery.contrib.testing.mocks import TaskMessage, TaskMessage1, task_message_from_sig

# Tricks flake8 into silencing redefining fixtures warnings.
# The four names are the fixtures imported from ``celery.contrib.pytest``.
__all__ = (
    'celery_app', 'celery_enable_logging', 'depends_on_current_app',
    'celery_parameters'
)


# NOTE(review): ``sys.version_info[0] > 3`` is false on every Python 3.x
# release, so this flag is always falsy as written; ``>= 3`` may have been
# intended -- confirm before relying on it.
PYPY3 = getattr(sys, 'pypy_version_info', None) and sys.version_info[0] > 3

# Error-message templates used by the sanity-check fixtures in this file;
# ``{0}`` receives the test name and ``{1}`` (where present) the stream name.
CASE_LOG_REDIRECT_EFFECT = 'Test {0} didn\'t disable LoggingProxy for {1}'
CASE_LOG_LEVEL_EFFECT = 'Test {0} modified the level of the root logger'
CASE_LOG_HANDLER_EFFECT = 'Test {0} modified handlers for the root logger'

# Unbound ``io.StringIO`` methods, captured once at import time.
_SIO_write = io.StringIO.write
_SIO_init = io.StringIO.__init__

# Unique marker object used to detect "no value supplied".
SENTINEL = object()


def noop(*args, **kwargs):
    """Accept any arguments, do nothing and return ``None``."""
    return None


class WhateverIO(io.StringIO):
    """In-memory text buffer that also accepts :class:`bytes`.

    ``bytes`` given to the constructor or to :meth:`write` are decoded with
    :meth:`bytes.decode` (default encoding) before being passed on to
    :class:`io.StringIO`; any other value is passed through unchanged.
    """

    def __init__(self, v=None, *a, **kw):
        super().__init__(v.decode() if isinstance(v, bytes) else v, *a, **kw)

    def write(self, data):
        """Write *data* and return the number of characters written.

        Returning the count keeps the ``TextIOBase.write`` contract, which
        callers of file-like objects are entitled to rely on.
        """
        text = data.decode() if isinstance(data, bytes) else data
        return super().write(text)


@pytest.fixture(scope='session')
def celery_config():
    """Session-wide Celery configuration used by the unit tests.

    The broker URL is ``memory://`` and the result backend is
    ``cache+memory://``.  MongoDB settings are read from ``MONGO_*``
    environment variables and fall back to local defaults when a variable
    is unset or empty.
    """
    env = os.environ.get
    queue_name = 'testcelery'

    # Mongo results tests (only executed if installed and running)
    mongodb_settings = {
        'host': env('MONGO_HOST') or 'localhost',
        'port': env('MONGO_PORT') or 27017,
        'database': env('MONGO_DB') or 'celery_unittests',
        'taskmeta_collection': (
            env('MONGO_TASKMETA_COLLECTION') or 'taskmeta_collection'
        ),
        'user': env('MONGO_USER'),
        'password': env('MONGO_PASSWORD'),
    }

    return {
        'broker_url': 'memory://',
        'broker_transport_options': {'polling_interval': 0.1},
        'result_backend': 'cache+memory://',
        'task_default_queue': queue_name,
        'task_default_exchange': queue_name,
        'task_default_routing_key': queue_name,
        'task_queues': (Queue(queue_name, routing_key=queue_name),),
        'accept_content': ('json', 'pickle'),
        'mongodb_backend_settings': mongodb_settings,
    }


@pytest.fixture(scope='session')
def use_celery_app_trap():
    """Session fixture whose value is always ``True``.

    NOTE(review): presumably overrides the same-named fixture of the
    celery pytest plugin so that the default app is trapped -- confirm.
    """
    trap_enabled = True
    return trap_enabled


@pytest.fixture(autouse=True)
def reset_cache_backend_state(celery_app):
    """Clear the cache result backend's stored state after each test.

    Only acts when the app already has a ``backend`` entry in its instance
    ``__dict__`` and that backend is a :class:`CacheBackend`.
    """
    yield
    # Looked up in ``__dict__`` rather than through attribute access.
    # NOTE(review): presumably to avoid creating the backend lazily -- confirm.
    current = celery_app.__dict__.get('backend')
    if not isinstance(current, CacheBackend):
        return
    if isinstance(current.client, DummyClient):
        current.client.cache.clear()
    current._cache.clear()


@contextmanager
def assert_signal_called(signal, **expected):
    """Assert that *signal* was dispatched inside the ``with`` block.

    A :class:`~unittest.mock.Mock` receiver is connected for the duration
    of the block and yielded to the caller.  When the block finishes
    without raising, the receiver's most recent call must have been made
    with ``signal=signal`` plus the *expected* keyword arguments.
    """
    receiver = Mock()

    def forward(**payload):
        # A real function is connected; it just relays to the mock.
        return receiver(**payload)

    signal.connect(forward)
    try:
        yield receiver
    finally:
        signal.disconnect(forward)
    receiver.assert_called_with(signal=signal, **expected)


@pytest.fixture
def app(celery_app):
    """Expose the ``celery_app`` fixture under the shorter name ``app``."""
    current_app = celery_app
    yield current_app


@pytest.fixture(autouse=True, scope='session')
def AAA_disable_multiprocessing():
    """Replace the known ``Process`` classes with mocks for the session.

    pytest-cov breaks if a multiprocessing.Process is started,
    so disable them completely to make sure it doesn't happen.
    """
    from contextlib import ExitStack

    # Each target appears exactly once: patching the same attribute twice
    # and undoing the patches first-to-last would leave a mock installed.
    targets = (
        'multiprocessing.Process',
        'billiard.Process',
        'billiard.context.Process',
        'billiard.process.Process',
        'billiard.process.BaseProcess',
    )
    # ExitStack undoes the patches in reverse order and also unwinds the
    # ones already applied if a later target cannot be patched.
    with ExitStack() as stack:
        for target in targets:
            stack.enter_context(patch(target))
        yield


def alive_threads():
    """Return the live threads, ignoring pytest-timeout's helper threads.

    A thread is skipped when its name starts with ``"pytest_timeout "``.
    """
    survivors = []
    for candidate in threading.enumerate():
        if candidate.name.startswith("pytest_timeout "):
            continue
        if candidate.is_alive():
            survivors.append(candidate)
    return survivors


@pytest.fixture(autouse=True)
def task_join_will_not_block():
    """Make ``task_join_will_block()`` report ``False`` while a test runs.

    The function is replaced on both ``celery.result`` and
    ``celery._state``; the previous objects are put back afterwards.
    """
    from celery import _state, result

    saved_on_result = result.task_join_will_block
    saved_on_state = _state.task_join_will_block
    # Keep the untouched function reachable as an attribute of ``_state``.
    _state.orig_task_join_will_block = saved_on_state

    # Both modules share one and the same replacement callable.
    result.task_join_will_block = lambda: False
    _state.task_join_will_block = result.task_join_will_block
    _state._set_task_join_will_block(False)

    yield

    result.task_join_will_block = saved_on_result
    _state.task_join_will_block = saved_on_state
    # The flag is set to False again on the way out rather than restored.
    _state._set_task_join_will_block(False)


@pytest.fixture(scope='session', autouse=True)
def record_threads_at_startup(request):
    """Snapshot the live threads once per session.

    The list is stored on ``request.session._threads_at_startup`` and is
    left untouched if that attribute already exists.
    """
    session = request.session
    if not hasattr(session, '_threads_at_startup'):
        session._threads_at_startup = alive_threads()


@pytest.fixture(autouse=True)
def threads_not_lingering(request):
    """Assert after each test that the live threads match the startup snapshot."""
    yield
    baseline = request.session._threads_at_startup
    current = alive_threads()
    assert baseline == current


@pytest.fixture(autouse=True)
def AAA_reset_CELERY_LOADER_env():
    """Assert after each test that ``CELERY_LOADER`` is unset or empty.

    Despite the name nothing is reset here; the variable is only checked.
    """
    yield
    loader = os.environ.get('CELERY_LOADER')
    assert not loader


@pytest.fixture(autouse=True)
def test_cases_shortcuts(request, app, patching, celery_config):
    """Attach commonly used helpers to the test class instance.

    Nothing is attached when ``request.instance`` is falsy.  After the
    test the instance's ``app`` attribute is set to ``None``.
    """
    instance = request.instance
    if instance:
        @app.task
        def add(x, y):
            return x + y

        # IMPORTANT: We set an .app attribute for every test case class.
        shortcuts = (
            ('app', app),
            ('Celery', TestApp),
            ('assert_signal_called', assert_signal_called),
            ('task_message_from_sig', task_message_from_sig),
            ('TaskMessage', TaskMessage),
            ('TaskMessage1', TaskMessage1),
            ('CELERY_TEST_CONFIG', celery_config),
            ('add', add),
            ('patching', patching),
        )
        for attr, helper in shortcuts:
            setattr(instance, attr, helper)
    yield
    if request.instance:
        request.instance.app = None


@pytest.fixture(autouse=True)
def sanity_no_shutdown_flags_set():
    """Assert after each test that the worker shutdown flags are clear."""
    yield

    # Make sure no test left the shutdown flags enabled.
    from celery.worker import state as worker_state

    flag_names = ('should_stop', 'should_terminate')

    # check for EX_OK
    # NOTE(review): this pass rejects only the literal ``False``; an integer
    # ``0`` satisfies every assertion here -- confirm that is the intended
    # EX_OK check.
    for flag_name in flag_names:
        assert getattr(worker_state, flag_name) is not False
    # check for other true values
    for flag_name in flag_names:
        assert not getattr(worker_state, flag_name)


@pytest.fixture(autouse=True)
def sanity_stdouts(request):
    """Check after each test that the standard streams were left usable.

    Raises:
        RuntimeError: if ``sys.stdout``/``sys.stderr`` (or their
            ``sys.__stdout__``/``sys.__stderr__`` counterparts) is a
            ``LoggingProxy`` or a ``Mock``.
    """
    yield

    from celery.utils.log import LoggingProxy

    for stream in (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__):
        assert stream
    test_name = request.node.name
    forbidden = (LoggingProxy, Mock)
    for label in ('stdout', 'stderr'):
        current = getattr(sys, label)
        original = getattr(sys, '__{}__'.format(label))
        if isinstance(current, forbidden) or isinstance(original, forbidden):
            raise RuntimeError(
                CASE_LOG_REDIRECT_EFFECT.format(test_name, label))


@pytest.fixture(autouse=True)
def sanity_logging_side_effects(request):
    """Detect tests that change the root logger's level or handlers.

    pytest's own ``LogCaptureHandler`` instances are ignored when the
    handler lists are compared.

    Raises:
        RuntimeError: if the level or the remaining handlers differ after
            the test.
    """
    from _pytest.logging import LogCaptureHandler

    def foreign_handlers(logger):
        return [
            handler for handler in logger.handlers
            if not isinstance(handler, LogCaptureHandler)
        ]

    root_before = logging.getLogger()
    level_before = root_before.level
    handlers_before = foreign_handlers(root_before)

    yield

    test_name = request.node.name
    root_after = logging.getLogger()
    if root_after.level != level_before:
        raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(test_name))
    if foreign_handlers(root_after) != handlers_before:
        raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(test_name))


def setup_session(scope='session'):
    """Prepare process-wide state for the test session.

    Sets the ``C_WNOCONF`` and ``KOMBU_DISABLE_LIMIT_PROTECTION``
    environment variables, pre-imports every module of the distribution
    when coverage is requested (and ``PYPY3`` is falsy), and installs a
    ``Trap`` as the default app.

    Args:
        scope: Not used by this function.
    """
    coverage_requested = (
        os.environ.get('COVER_ALL_MODULES') or '--with-coverage' in sys.argv
    )
    # warn if config module not found
    os.environ['C_WNOCONF'] = 'yes'
    os.environ['KOMBU_DISABLE_LIMIT_PROTECTION'] = 'yes'

    if coverage_requested and not PYPY3:
        # Warnings raised while importing are recorded and then dropped.
        with warnings.catch_warnings(record=True):
            import_all_modules()
        warnings.resetwarnings()

    from celery._state import set_default_app
    set_default_app(Trap())


def teardown():
    """Clean up process-wide leftovers at the end of the test run.

    * Raises the :mod:`multiprocessing` logger to ``WARNING``.
    * Removes a ``test.db`` file from the current working directory
      (best effort).
    * Reports on ``sys.stderr`` any thread other than ``MainThread`` that
      is still listed by :func:`threading.enumerate`.
    """
    # Don't want SUBDEBUG log messages at finalization.
    try:
        from multiprocessing.util import get_logger
    except ImportError:
        pass
    else:
        get_logger().setLevel(logging.WARNING)

    # Make sure test database is removed.
    try:
        os.remove('test.db')
    except OSError:
        # Covers both "no such file" and a failed removal; this cleanup is
        # deliberately best effort.
        pass

    # Make sure there are no remaining threads at shutdown.
    # ``Thread.name`` is used because ``Thread.getName()`` is deprecated
    # since Python 3.10.
    remaining_threads = [thread for thread in threading.enumerate()
                         if thread.name != 'MainThread']
    if remaining_threads:
        sys.stderr.write(
            '\n\n**WARNING**: Remaining threads at teardown: %r...\n' % (
                remaining_threads,))


def find_distribution_modules(name=__name__, file=__file__):
    """Yield the dotted names of all packages and modules in a distribution.

    The distribution root is found by walking up from the directory of
    *file*, one level per dot in *name*.  Every directory below the root
    that contains an ``__init__.py`` is reported as a package, followed by
    each of its other ``.py`` files as a module.

    Args:
        name: Dotted module name of *file*.
        file: Path of a module located inside the distribution.

    Yields:
        str: Dotted package and module names, prefixed with the base name
        of the root directory.
    """
    depth = len(name.split('.')) - 1
    root = os.path.abspath(
        os.path.join(os.path.dirname(file), *([os.pardir] * depth)))
    dist_name = os.path.basename(root)

    for dirpath, _dirnames, filenames in os.walk(root):
        if '__init__.py' not in filenames:
            continue
        # os.walk joins path components with os.sep, so convert that
        # separator (not a hard-coded '/') into dots.
        package = (dist_name + dirpath[len(root):]).replace(os.sep, '.')
        yield package
        for filename in filenames:
            if filename.endswith('.py') and filename != '__init__.py':
                yield '.'.join([package, filename[:-3]])


def import_all_modules(name=__name__, file=__file__,
                       skip=('celery.decorators',
                             'celery.task')):
    """Import every module reported by :func:`find_distribution_modules`.

    ``ImportError`` is silently ignored; ``OSError`` is reported as a
    :class:`UserWarning`.

    Args:
        name: Passed through to :func:`find_distribution_modules`.
        file: Passed through to :func:`find_distribution_modules`.
        skip: Tuple of dotted-name prefixes; a module whose name starts
            with any of them is not imported.
    """
    for dotted in find_distribution_modules(name, file):
        if dotted.startswith(skip):
            continue
        try:
            import_module(dotted)
        except ImportError:
            pass
        except OSError as exc:
            message = 'Ignored error importing module {}: {!r}'.format(
                dotted, exc,
            )
            warnings.warn(UserWarning(message))


@pytest.fixture
def sleepdeprived(request):
    """Replace ``sleep`` on a marked module with a no-op for one test.

    The module is the first argument of the test's
    ``sleepdeprived_patched_module`` marker; its original ``sleep`` is
    restored afterwards.

    Example:
        >>> import time
        >>> @pytest.mark.sleepdeprived_patched_module(time)
        >>> def test_foo(self, sleepdeprived):
        >>>     pass
    """
    marker = request.node.get_closest_marker("sleepdeprived_patched_module")
    target = marker.args[0]
    original_sleep = target.sleep
    target.sleep = noop
    try:
        yield
    finally:
        target.sleep = original_sleep


# Taken from
# http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
@pytest.fixture
def mask_modules(request):
    """Make the modules named by the ``masked_modules`` marker unimportable.

    ``builtins.__import__`` is swapped for a wrapper that raises
    :class:`ImportError` for the masked names and delegates every other
    name to the real implementation.  The real ``__import__`` is put back
    when the test ends.

    Example::
        >>> @pytest.mark.masked_modules('gevent.monkey')
        >>> def test_foo(self, mask_modules):
        ...     try:
        ...         import gevent.monkey
        ...     except ImportError:
        ...         print('gevent.monkey not found')
        gevent.monkey not found
    """
    original_import = builtins.__import__
    masked = request.node.get_closest_marker("masked_modules").args

    def guarded_import(name, *args, **kwargs):
        if name not in masked:
            return original_import(name, *args, **kwargs)
        raise ImportError('No module named %s' % name)

    builtins.__import__ = guarded_import
    try:
        yield
    finally:
        builtins.__import__ = original_import


@pytest.fixture
def environ(request):
    """Set one environment variable for the duration of a test.

    The name and value are the two arguments of the test's
    ``patched_environ`` marker.  Afterwards the variable gets its previous
    value back, or is removed if it did not exist before.

    Example::
        >>> @pytest.mark.patched_environ('DJANGO_SETTINGS_MODULE', 'proj.settings')
        >>> def test_other_settings(self, environ):
        ...    ...
    """
    marker = request.node.get_closest_marker("patched_environ")
    env_name, env_value = marker.args
    previous = os.environ.get(env_name, SENTINEL)
    os.environ[env_name] = env_value
    try:
        yield
    finally:
        if previous is not SENTINEL:
            os.environ[env_name] = previous
        else:
            os.environ.pop(env_name, None)


def replace_module_value(module, name, value=None):
    """Temporarily replace (or remove) an attribute of *module*.

    This is a plain generator with a single ``yield``; it is meant to be
    driven with ``yield from`` inside a context manager or fixture.  When
    the generator is resumed or closed, the attribute is put back exactly
    as it was, or removed again if it did not exist before.

    Args:
        module: Object whose attribute is patched.
        name: Attribute name.
        value: Replacement.  A falsy value (including the default
            ``None``) removes the attribute instead of setting it.

    Example::

        >>> replace_module_value(module, 'CONSTANT', 3.03)
    """
    had_attr = hasattr(module, name)
    previous = getattr(module, name, None)
    if value:
        setattr(module, name, value)
    else:
        try:
            delattr(module, name)
        except AttributeError:
            pass
    try:
        yield
    finally:
        if had_attr:
            # Decide on existence, not on the old value: an attribute that
            # was legitimately ``None`` must be restored as well.
            setattr(module, name, previous)
        else:
            try:
                delattr(module, name)
            except AttributeError:
                pass


@contextmanager
def platform_pyimp(value=None):
    """Temporarily override :data:`platform.python_implementation`.

    Args:
        value: Replacement attribute, handed to
            :func:`replace_module_value` unchanged.

    Example::
        >>> with platform_pyimp('PyPy'):
        ...     ...
    """
    patcher = replace_module_value(platform, 'python_implementation', value)
    yield from patcher


@contextmanager
def sys_platform(value=None):
    """Temporarily set :data:`sys.platform` to *value*.

    The previous value is put back when the block exits, also on error.

    Example::
        >>> with sys_platform('darwin'):
        ...     ...
    """
    original = sys.platform
    sys.platform = value
    try:
        yield
    finally:
        sys.platform = original


@contextmanager
def pypy_version(value=None):
    """Temporarily override :data:`sys.pypy_version_info`.

    Args:
        value: Replacement attribute, handed to
            :func:`replace_module_value` unchanged.

    Example::
        >>> with pypy_version((3, 6, 1)):
        ...     ...
    """
    patcher = replace_module_value(sys, 'pypy_version_info', value)
    yield from patcher


def _restore_logging():
    """Generator that snapshots and later restores global logging state.

    Captured before the single ``yield`` and put back afterwards:
    ``sys.stdout``, ``sys.stderr``, ``sys.__stdout__``, ``sys.__stderr__``
    and the root logger's level and handlers.
    """
    outs = sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__
    root = logging.getLogger()
    level = root.level
    # Copy the list: ``addHandler``/``removeHandler`` mutate
    # ``root.handlers`` in place, so a bare reference would snapshot nothing.
    handlers = root.handlers[:]

    try:
        yield
    finally:
        sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
        # setLevel (rather than assigning ``root.level``) also invalidates
        # the logging module's cached ``isEnabledFor`` results.
        root.setLevel(level)
        root.handlers[:] = handlers


@contextmanager
def restore_logging_context_manager():
    """Restore root logger handlers after test returns.
    Example::
        >>> with restore_logging_context_manager():
        ...     setup_logging()
    """
    yield from _restore_logging()


@pytest.fixture
def restore_logging(request):
    """Fixture that puts the global logging state back after the test.

    Delegates to :func:`_restore_logging`.

    Example::
        >>> def test_foo(self, restore_logging):
        ...     setup_logging()
    """
    snapshot = _restore_logging()
    yield from snapshot


@pytest.fixture
def module(request):
    """Install mock modules named by the test's ``patched_module`` marker.

    Yields the list of mock modules produced by :func:`_module`.
    """
    names = request.node.get_closest_marker("patched_module").args
    yield from _module(*names)


@contextmanager
def module_context_manager(*names):
    """Context-manager form of :func:`_module`.

    Replaces the named ``sys.modules`` entries with mock modules for the
    duration of the block and yields the list of those modules.
    """
    installer = _module(*names)
    yield from installer


def _module(*names):
    """Generator that swaps ``sys.modules`` entries for mock modules.

    Each name gets a fresh module object on which every missing attribute
    is created on first access as a :class:`~unittest.mock.Mock`.  The list
    of these modules is yielded; afterwards the previous ``sys.modules``
    entries are restored and names that had none are removed.
    """
    class MockModule(types.ModuleType):

        def __getattr__(self, attr):
            # Create the attribute on first access so later lookups find
            # the very same Mock.
            placeholder = Mock()
            setattr(self, attr, placeholder)
            return types.ModuleType.__getattribute__(self, attr)

    saved = {}
    installed = []
    for name in names:
        if name in sys.modules:
            saved[name] = sys.modules[name]
        fake = MockModule(name)
        sys.modules[name] = fake
        installed.append(fake)
    try:
        yield installed
    finally:
        for name in names:
            if name in saved:
                sys.modules[name] = saved[name]
            else:
                sys.modules.pop(name, None)


class _patching:
    """Mock-friendly front end to pytest's ``monkeypatch`` fixture.

    Calling an instance patches a dotted import path and returns the value
    that was installed (a new ``MagicMock`` unless one is given).  Any
    attribute not defined on this class is looked up on the wrapped
    ``monkeypatch`` object.
    """

    def __init__(self, monkeypatch, request):
        self.monkeypatch = monkeypatch
        self.request = request

    def __getattr__(self, name):
        # Only called when normal lookup fails: defer to monkeypatch.
        return getattr(self.monkeypatch, name)

    def __call__(self, path, value=SENTINEL, name=None,
                 new=MagicMock, **kwargs):
        """Patch the dotted import *path* and return the installed value.

        Args:
            path: Dotted import string handed to ``monkeypatch.setattr``.
            value: Replacement object; when omitted ``new(name=...)`` is
                created instead.
            name: Name for the created mock; defaults to the last
                component of *path*.
            new: Factory used when *value* is omitted.
            **kwargs: Set as attributes on the installed value.
        """
        value = self._value_or_mock(value, new, name, path, **kwargs)
        self.monkeypatch.setattr(path, value)
        return value

    def object(self, target, attribute, *args, **kwargs):
        """Apply ``patch.object`` now and undo it when the request finalizes."""
        return _wrap_context(
            patch.object(target, attribute, *args, **kwargs),
            self.request)

    def _value_or_mock(self, value, new, name, path, **kwargs):
        """Return *value*, or a new mock if it is ``SENTINEL``.

        The mock is named *name*, falling back to the part of *path* after
        its last dot.  *kwargs* are set as attributes on the result.
        """
        if value is SENTINEL:
            value = new(name=name or path.rpartition('.')[2])
        for k, v in kwargs.items():
            setattr(value, k, v)
        return value

    def setattr(self, target, name=SENTINEL, value=SENTINEL, **kwargs):
        """Patch an attribute using the ``monkeypatch.setattr`` call forms.

        ``setattr('dotted.path', value)`` and ``setattr('dotted.path')``
        are aliases of calling the instance.  ``setattr(obj, 'attr', value)``
        with a non-string *obj* is forwarded to
        ``monkeypatch.setattr(obj, 'attr', value)``.  Extra keyword
        arguments are accepted and ignored.
        """
        if value is SENTINEL:
            value, name = name, None
        elif name is not SENTINEL and not isinstance(target, str):
            # Three-argument form: ``monkeypatch.setattr(target, value)``
            # only accepts a dotted import string as target, so the
            # attribute name must be passed along explicitly here.
            self.monkeypatch.setattr(target, name, value)
            return value
        return self(target, value, name=name)

    def setitem(self, dic, name, value=SENTINEL, new=MagicMock, **kwargs):
        """Same as ``monkeypatch.setitem`` but the default value is a mock.

        Returns the value stored under ``dic[name]``.
        """
        # The mock name is derived from the key; ``str(name)`` is the
        # fallback so a falsy key does not end up calling ``rpartition``
        # on the mapping itself.
        value = self._value_or_mock(value, new, name, str(name), **kwargs)
        self.monkeypatch.setitem(dic, name, value)
        return value

    def modules(self, *mods):
        """Install mock modules for *mods* and all of their parent packages.

        The mocks are removed again when the request finalizes.  Returns
        the list of mock modules, ordered by sorted dotted name.
        """
        wanted = set()
        for mod in mods:
            parts = mod.split('.')
            # 'a.b.c' -> 'a', 'a.b', 'a.b.c'
            wanted.update(
                '.'.join(parts[:i]) for i in range(1, len(parts) + 1))
        return _wrap_context(
            module_context_manager(*sorted(wanted)), self.request)


def _wrap_context(context, request):
    """Enter *context* now and leave it when the test request finalizes.

    Returns whatever ``context.__enter__()`` returned.
    """
    entered = context.__enter__()

    def leave_context():
        context.__exit__(*sys.exc_info())

    request.addfinalizer(leave_context)
    return entered


@pytest.fixture()
def patching(monkeypatch, request):
    """Provide a :class:`_patching` helper bound to the current test.

    Example:
        .. code-block:: python
        >>> def test_foo(patching):
        >>>     # execv value here will be mock.MagicMock by default.
        >>>     execv = patching('os.execv')
        >>>     patching('sys.platform', 'darwin')  # set concrete value
        >>>     patching.setenv('DJANGO_SETTINGS_MODULE', 'x.settings')
        >>>     # val will be of type mock.MagicMock by default
        >>>     val = patching.setitem(some_dict, 'KEY')
    """
    helper = _patching(monkeypatch, request)
    return helper


@contextmanager
def stdouts():
    """Swap the standard output streams for in-memory buffers.

    ``sys.stdout`` and ``sys.__stdout__`` share one :class:`WhateverIO`,
    ``sys.stderr`` and ``sys.__stderr__`` share another.  The pair
    ``(stdout, stderr)`` is yielded and all four streams are put back on
    exit.

        >>> with conftest.stdouts() as (stdout, stderr):
        ...     something()
        ...     self.assertIn('foo', stdout.getvalue())
    """
    saved = {
        attr: getattr(sys, attr)
        for attr in ('stdout', 'stderr', '__stdout__', '__stderr__')
    }
    fake_out, fake_err = WhateverIO(), WhateverIO()
    sys.stdout = sys.__stdout__ = fake_out
    sys.stderr = sys.__stderr__ = fake_err

    try:
        yield fake_out, fake_err
    finally:
        for attr, stream in saved.items():
            setattr(sys, attr, stream)


@contextmanager
def reset_modules(*modules):
    """Force a fresh import of *modules* for the duration of the block.

    The named entries are taken out of :data:`sys.modules`, each module is
    imported and reloaded, and on exit the entries that were removed are
    put back.

    Example::
        >>> with conftest.reset_modules('celery.result', 'celery.app.base'):
        ...     pass
    """
    evicted = {}
    for modname in modules:
        if modname in sys.modules:
            evicted[modname] = sys.modules.pop(modname)

    try:
        for modname in modules:
            fresh = import_module(modname)
            reload(fresh)
        yield
    finally:
        sys.modules.update(evicted)


def get_logger_handlers(logger):
    """Return *logger*'s handlers without any :class:`logging.NullHandler`."""
    real_handlers = []
    for handler in logger.handlers:
        if isinstance(handler, logging.NullHandler):
            continue
        real_handlers.append(handler)
    return real_handlers


@contextmanager
def wrap_logger(logger, loglevel=logging.ERROR):
    """Route *logger*'s output into an in-memory buffer.

    While the block runs the logger's only handler is a
    :class:`logging.StreamHandler` writing to a :class:`WhateverIO`, which
    is yielded.  On exit the logger gets back the handlers returned by
    :func:`get_logger_handlers` before the swap.

    Args:
        logger: Logger whose handlers are replaced.
        loglevel: Accepted but not used by this function.

    Example::
        >>> with conftest.wrap_logger(logger, loglevel=logging.DEBUG) as sio:
        ...     ...
        ...     sio.getvalue()
    """
    previous = get_logger_handlers(logger)
    buffer = WhateverIO()
    logger.handlers = [logging.StreamHandler(buffer)]

    try:
        yield buffer
    finally:
        logger.handlers = previous


@contextmanager
def _mock_context(mock):
    """Give *mock* a return value that works as a context manager.

    The return value is a new :class:`Mock` whose ``__enter__`` returns
    the mock itself and whose ``__exit__`` re-raises the exception type it
    is called with, if any.  That mock is yielded.
    """
    manager = Mock()
    mock.return_value = manager
    manager.__enter__ = Mock()
    manager.__exit__ = Mock()

    def reraise(*exc_details):
        if exc_details[0]:
            raise exc_details[0] from exc_details[1]

    manager.__exit__.side_effect = reraise
    manager.__enter__.return_value = manager
    try:
        yield manager
    finally:
        # NOTE(review): Mock's reset API is ``reset_mock``; ``reset`` only
        # calls an auto-created child mock -- confirm the intent.
        manager.reset()


@contextmanager
def open(side_effect=None):
    """Patch ``builtins.open`` so that it hands out an in-memory buffer.

    The patched ``open`` returns a mock context manager whose
    ``__enter__`` return value is set to a new :class:`WhateverIO`; that
    buffer is yielded.

    :param side_effect: Optional side effect assigned to the mocked
        ``__enter__`` of the open context.

    Example::
        >>> with mock.open(io.BytesIO) as open_fh:
        ...     something_opening_and_writing_bytes_to_a_file()
        ...     self.assertIn(b'foo', open_fh.getvalue())
    """
    with patch('builtins.open') as patched_open, \
            _mock_context(patched_open) as manager:
        if side_effect is not None:
            manager.__enter__.side_effect = side_effect
        fake_file = WhateverIO()
        manager.__enter__.return_value = fake_file
        fake_file.__exit__ = Mock()
        yield fake_file


@contextmanager
def module_exists(*modules):
    """Make the given modules importable by registering them in ``sys.modules``.

    Args:
        *modules: Module objects or module names.  A name is turned into
            an empty module of that name.

    For a dotted name the module is also set as an attribute on its
    parent, which must already be present in ``sys.modules`` (so pass the
    parent first).  On exit the registered names are removed and every
    module that was replaced is put back.

    Example::
        >>> with conftest.module_exists('gevent', 'gevent.monkey'):
        ...     gevent.monkey.patch_all = Mock(name='patch_all')
        ...     ...
    """
    registered = []
    displaced = []
    for entry in modules:
        mod = types.ModuleType(entry) if isinstance(entry, str) else entry
        modname = mod.__name__
        registered.append(mod)
        if modname in sys.modules:
            displaced.append(sys.modules[modname])
        sys.modules[modname] = mod
        if '.' in modname:
            parent_name, _, leaf = modname.rpartition('.')
            setattr(sys.modules[parent_name], leaf, mod)
    try:
        yield
    finally:
        for mod in registered:
            sys.modules.pop(mod.__name__, None)
        for mod in displaced:
            sys.modules[mod.__name__] = mod


def _bind(f, o):
    """Return a function that calls ``f`` with ``o`` as its first argument.

    The returned wrapper carries ``f``'s metadata via
    :func:`functools.wraps`.
    """
    def call_with_instance(*fargs, **fkwargs):
        return f(o, *fargs, **fkwargs)
    return wraps(f)(call_with_instance)


class MockCallbacks:
    # Base class whose "instances" are really ``Mock`` objects.
    #
    # Instantiating the class returns a ``Mock`` named after the class.
    # ``__init__`` is run against that mock, then every function found in
    # the class's own namespace becomes the ``side_effect`` of the
    # same-named mock attribute (with the mock as first argument); other
    # class attributes are copied onto the mock as-is.

    def __new__(cls, *args, **kwargs):
        skipped = ('__dict__', '__weakref__', '__new__', '__init__')
        stand_in = Mock(name=cls.__name__)
        cls.__init__(stand_in, *args, **kwargs)
        for key, value in vars(cls).items():
            if key in skipped:
                continue
            if inspect.ismethod(value) or inspect.isfunction(value):
                stand_in.__getattr__(key).side_effect = _bind(value, stand_in)
            else:
                stand_in.__setattr__(key, value)
        return stand_in