File: patch.py

package info (click to toggle)
python-mongomock 4.3.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,028 kB
  • sloc: python: 16,412; makefile: 24
file content (92 lines) | stat: -rw-r--r-- 3,158 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from unittest import mock

from .mongo_client import MongoClient
import time

try:
    import pymongo
    from pymongo.uri_parser import parse_uri, split_hosts
    _IMPORT_PYMONGO_ERROR = None
except ImportError as error:
    from .helpers import parse_uri, split_hosts
    _IMPORT_PYMONGO_ERROR = error


def _parse_any_host(host, default_port=27017):
    """Expand one server spec into the node entries it designates.

    Args:
        host: either a (host, port) tuple, a MongoDB URI (any string
            containing '://'), or a plain host string for ``split_hosts``.
        default_port: passed to ``split_hosts`` as its default port. When
            ``host`` is a tuple, its second item is used instead.

    Returns:
        For a URI, the ``'nodelist'`` entry of ``parse_uri``'s result;
        otherwise whatever ``split_hosts`` returns (presumably a list of
        (hostname, port) pairs — confirm against pymongo/helpers).
    """
    if isinstance(host, tuple):
        # Recurse (rather than inline) so the first item may itself be any
        # supported spec; the tuple's port overrides the default port.
        name, port = host[0], host[1]
        return _parse_any_host(name, port)

    if '://' not in host:
        return split_hosts(host, default_port=default_port)

    # A full URI carries its own ports, so default_port is not needed here.
    return parse_uri(host, warn=True)['nodelist']


def patch(servers='localhost', on_new='error'):
    """Patch pymongo.MongoClient.

    This will patch the class MongoClient and use mongomock to mock MongoDB
    servers. It keeps a consistent state of servers across multiple clients so
    you can do:

    ```
    client = pymongo.MongoClient(host='localhost', port=27017)
    client.db.coll.insert_one({'name': 'Pascal'})

    other_client = pymongo.MongoClient('mongodb://localhost:27017')
    other_client.db.coll.find_one()
    ```

    The data is persisted as long as the patch lives.

    Args:
        servers: the servers that are available. Either a single server or a
            list/tuple of servers, where a server is a host string (e.g.
            'localhost' or 'localhost:27017'), a MongoDB URI (e.g.
            'mongodb://localhost:27017') or a (host, port) tuple.
        on_new: Behavior when accessing a new server (not in servers):
            'create': mock a new empty server, accept any client connection.
            'error': raise a ValueError immediately when trying to access.
            'timeout': behave as pymongo when a server does not exist, raise an
                error after a timeout (the client's serverSelectionTimeoutMS,
                30000 ms if not given).
            'pymongo': use an actual pymongo client.

    Returns:
        The patcher created by ``unittest.mock.patch`` for
        'pymongo.MongoClient' (usable as a decorator, a context manager, or
        via ``start()``/``stop()``).

    Raises (when a client is created while the patch is active):
        ImportError: pymongo could not be imported.
        ValueError: the server is unknown and on_new is 'error' (or any value
            not listed above).
        pymongo.errors.ServerSelectionTimeoutError: the server is unknown and
            on_new is 'timeout'.
    """
    if _IMPORT_PYMONGO_ERROR:
        PyMongoClient = None
    else:
        # Capture the real class now, before this patch replaces it, so the
        # 'pymongo' fallback builds genuine clients.
        PyMongoClient = pymongo.MongoClient

    # Wrap a lone server spec in a list so it is not iterated element-wise:
    # a string, or a (host, port) pair (a 2-tuple whose second item is an int
    # port could never be a valid list of hosts).
    is_single_pair = (
        isinstance(servers, tuple) and len(servers) == 2
        and isinstance(servers[1], int))
    if is_single_pair or not isinstance(servers, (list, tuple)):
        servers = [servers]

    # First client created for each address, keyed by client.address; later
    # clients for the same address share its _store.
    persisted_clients = {}
    parsed_servers = set()
    for server in servers:
        parsed_servers.update(_parse_any_host(server))

    def _create_persistent_client(*args, **kwargs):
        if _IMPORT_PYMONGO_ERROR:
            raise _IMPORT_PYMONGO_ERROR  # pylint: disable=raising-bad-type

        client = MongoClient(*args, **kwargs)

        # Known address: reuse the existing store so data persists across
        # clients.
        persisted_client = persisted_clients.get(client.address)
        if persisted_client is not None:
            client._store = persisted_client._store
            return client

        if client.address in parsed_servers or on_new == 'create':
            persisted_clients[client.address] = client
            return client

        if on_new == 'timeout':
            # TODO(pcorpet): Only wait when trying to access the server's data.
            # serverSelectionTimeoutMS is in milliseconds while time.sleep
            # takes seconds.
            timeout_ms = kwargs.get('serverSelectionTimeoutMS', 30000)
            time.sleep(timeout_ms / 1000)
            raise pymongo.errors.ServerSelectionTimeoutError(
                '%s:%d: [Errno 111] Connection refused' % client.address)

        if on_new == 'pymongo':
            return PyMongoClient(*args, **kwargs)

        raise ValueError(
            ('MongoDB server %s:%d does not exist.\n' % client.address)
            + '%s' % parsed_servers)

    class _PersistentClient:
        # Stand-in for pymongo.MongoClient. __new__ returns the object built by
        # _create_persistent_client, which is not a _PersistentClient, so
        # Python never calls __init__ on it.
        def __new__(cls, *args, **kwargs):
            return _create_persistent_client(*args, **kwargs)

    return mock.patch('pymongo.MongoClient', _PersistentClient)