1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
|
from unittest import mock
from .mongo_client import MongoClient
import time
try:
import pymongo
from pymongo.uri_parser import parse_uri, split_hosts
_IMPORT_PYMONGO_ERROR = None
except ImportError as error:
from .helpers import parse_uri, split_hosts
_IMPORT_PYMONGO_ERROR = error
def _parse_any_host(host, default_port=27017):
    """Parse a host description into the node list used to identify servers.

    Args:
        host: one of
            - a tuple whose first item is itself a host description and whose
              second item is the port to use as default for it,
            - a URI string (anything containing '://'), handed to `parse_uri`,
            - any other string, handed to `split_hosts`.
        default_port: port passed to `split_hosts` as its `default_port`.

    Returns:
        The 'nodelist' entry of `parse_uri`'s result for URIs, otherwise
        whatever `split_hosts` returns (presumably (host, port) pairs in both
        cases -- confirm against the pymongo / helpers implementation).
    """
    # Peel off (host, port) wrappers: the port of the innermost tuple wins.
    while isinstance(host, tuple):
        host, default_port = host[0], host[1]

    if '://' not in host:
        return split_hosts(host, default_port=default_port)

    parsed_uri = parse_uri(host, warn=True)
    return parsed_uri['nodelist']
def patch(servers='localhost', on_new='error'):
    """Patch pymongo.MongoClient.

    This will patch the class MongoClient and use mongomock to mock MongoDB
    servers. It keeps a consistent state of servers across multiple clients so
    you can do:

    ```
    client = pymongo.MongoClient(host='localhost', port=27017)
    client.db.coll.insert_one({'name': 'Pascal'})

    other_client = pymongo.MongoClient('mongodb://localhost:27017')
    other_client.db.coll.find_one()
    ```

    The data is persisted as long as the patch lives.

    Args:
        on_new: Behavior when accessing a new server (not in servers):
            'create': mock a new empty server, accept any client connection.
            'error': raise a ValueError immediately when trying to access.
            'timeout': behave as pymongo when a server does not exist, raise an
                error after a timeout (the client's `serverSelectionTimeoutMS`
                keyword argument, 30000 ms when not given).
            'pymongo': use an actual pymongo client.
            Any other value behaves like 'error'.
        servers: a server, or a list/tuple of servers, that are available.
            Each one is parsed with `_parse_any_host`.

    Returns:
        The patcher created by `mock.patch('pymongo.MongoClient', ...)`; use
        it as a decorator, a context manager, or through start()/stop().
    """
    # Capture the genuine pymongo client class now, so that the 'pymongo' mode
    # can still reach it once 'pymongo.MongoClient' has been replaced.
    if _IMPORT_PYMONGO_ERROR:
        PyMongoClient = None
    else:
        PyMongoClient = pymongo.MongoClient

    # First client created for each address; its store is shared with every
    # later client that connects to the same address.
    persisted_clients = {}

    parsed_servers = set()
    for server in servers if isinstance(servers, (list, tuple)) else [servers]:
        parsed_servers.update(_parse_any_host(server))

    def _create_persistent_client(*args, **kwargs):
        """Build the client returned in place of pymongo.MongoClient(...)."""
        if _IMPORT_PYMONGO_ERROR:
            raise _IMPORT_PYMONGO_ERROR  # pylint: disable=raising-bad-type

        client = MongoClient(*args, **kwargs)

        try:
            persisted_client = persisted_clients[client.address]
        except KeyError:
            pass
        else:
            # Known address: reuse the data of the first client.
            client._store = persisted_client._store
            return client

        if client.address in parsed_servers or on_new == 'create':
            persisted_clients[client.address] = client
            return client

        if on_new == 'timeout':
            # TODO(pcorpet): Only wait when trying to access the server's data.
            # The option is expressed in milliseconds whereas time.sleep()
            # expects seconds.
            timeout_ms = kwargs.get('serverSelectionTimeoutMS', 30000)
            time.sleep(timeout_ms / 1000.0)
            raise pymongo.errors.ServerSelectionTimeoutError(
                '%s:%d: [Errno 111] Connection refused' % client.address)

        if on_new == 'pymongo':
            return PyMongoClient(*args, **kwargs)

        raise ValueError(
            'MongoDB server %s:%d does not exist.\n' % client.address + '%s' % parsed_servers)

    class _PersistentClient:
        """Stand-in class whose construction is delegated to the factory above."""

        def __new__(cls, *args, **kwargs):
            return _create_persistent_client(*args, **kwargs)

    return mock.patch('pymongo.MongoClient', _PersistentClient)
|