1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187
|
import collections
import datetime
import functools
import mongomock
from mongomock.thread import RWLock
class ServerStore:
    """Object holding the data for a whole server (many databases).

    Database stores are created lazily the first time they are indexed with
    ``store[db_name]``; a database only counts as existing once its store
    reports ``is_created``.
    """

    def __init__(self):
        # db_name -> DatabaseStore, populated lazily by __getitem__.
        self._databases = {}

    def __getitem__(self, db_name):
        """Return the store for ``db_name``, registering an empty one on first access."""
        try:
            return self._databases[db_name]
        except KeyError:
            db = self._databases[db_name] = DatabaseStore()
            return db

    def __contains__(self, db_name):
        """Whether ``db_name`` is known and its store reports ``is_created``.

        Unlike ``self[db_name]`` this never registers a new, empty store for an
        unknown name, so membership probes do not grow ``_databases``.
        """
        db = self._databases.get(db_name)
        return db is not None and bool(db.is_created)

    def list_created_database_names(self):
        """Return the names of the registered databases whose store reports ``is_created``."""
        return [name for name, db in self._databases.items() if db.is_created]
class DatabaseStore:
    """Object holding the data for a database (many collections).

    Collection stores are created lazily the first time they are indexed with
    ``store[col_name]``; a collection only counts as existing once its store
    reports ``is_created``.
    """

    def __init__(self):
        # col_name -> CollectionStore, populated lazily by __getitem__.
        self._collections = {}

    def __getitem__(self, col_name):
        """Return the store for ``col_name``, registering an empty one on first access."""
        try:
            return self._collections[col_name]
        except KeyError:
            col = self._collections[col_name] = CollectionStore(col_name)
            return col

    def __contains__(self, col_name):
        """Whether ``col_name`` is known and its store reports ``is_created``.

        Unlike ``self[col_name]`` this never registers a new, empty collection
        store for an unknown name, so membership probes do not grow
        ``_collections``.
        """
        col = self._collections.get(col_name)
        return col is not None and bool(col.is_created)

    def list_created_collection_names(self):
        """Return the names of the registered collections whose store reports ``is_created``."""
        return [name for name, col in self._collections.items() if col.is_created]

    def create_collection(self, name):
        """Get (or register) the collection ``name``, mark it created, and return it."""
        col = self[name]
        col.create()
        return col

    def rename(self, name, new_name):
        """Move the collection stored under ``name`` to ``new_name``.

        The collection's ``name`` attribute is updated as well. If ``name`` is
        unknown, a fresh empty collection is registered under ``new_name``. Any
        collection already stored under ``new_name`` is replaced.
        """
        col = self._collections.pop(name, None)
        if col is None:
            # Only build the fallback store when it is actually needed.
            col = CollectionStore(new_name)
        col.name = new_name
        self._collections[new_name] = col

    @property
    def is_created(self):
        """Whether any registered collection store reports ``is_created``."""
        return any(col.is_created for col in self._collections.values())
class CollectionStore:
    """Object holding the data for a collection.

    Documents live in an ``OrderedDict`` keyed by whatever key the caller
    passes to ``__setitem__``. Indexes whose definition carries a non-None
    ``expireAfterSeconds`` are tracked as TTL indexes, and documents they have
    expired are purged lazily on reads (``in``, ``[]``, ``len``, ``is_empty``,
    ``documents``) and before an index is dropped.
    """

    def __init__(self, name):
        self._documents = collections.OrderedDict()
        self.indexes = {}
        self._is_force_created = False
        self.name = name
        # index_name -> index_dict for indexes with a non-None expireAfterSeconds.
        self._ttl_indexes = {}
        # 694 - Lock for safely iterating and mutating OrderedDicts
        self._rwlock = RWLock()

    def create(self):
        """Mark the collection as created even if it holds no documents or indexes."""
        self._is_force_created = True

    @property
    def is_created(self):
        """``True`` if the collection has documents or indexes, or ``create`` was called."""
        return bool(self._documents or self.indexes or self._is_force_created)

    def drop(self):
        """Remove all documents and indexes and clear the force-created flag."""
        self._documents = collections.OrderedDict()
        self.indexes = {}
        self._ttl_indexes = {}
        self._is_force_created = False

    def create_index(self, index_name, index_dict):
        """Register ``index_dict`` under ``index_name``.

        If the definition has a non-None ``expireAfterSeconds`` it is also
        tracked as a TTL index.
        """
        self.indexes[index_name] = index_dict
        if index_dict.get('expireAfterSeconds') is not None:
            self._ttl_indexes[index_name] = index_dict

    def drop_index(self, index_name):
        """Remove the index ``index_name``.

        Expired documents are purged first, while the index being dropped is
        still registered.

        Raises:
            KeyError: if no index named ``index_name`` exists.
        """
        self._remove_expired_documents()
        # The main index object should raise a KeyError, but the
        # TTL indexes have no meaning to the outside.
        del self.indexes[index_name]
        self._ttl_indexes.pop(index_name, None)

    @property
    def is_empty(self):
        """Whether no documents remain after purging expired ones."""
        self._remove_expired_documents()
        return not self._documents

    def __contains__(self, key):
        """Whether a document is stored under ``key`` (after purging expired ones)."""
        self._remove_expired_documents()
        with self._rwlock.reader():
            return key in self._documents

    def __getitem__(self, key):
        """Return the document stored under ``key``.

        Raises:
            KeyError: if no (unexpired) document is stored under ``key``.
        """
        self._remove_expired_documents()
        with self._rwlock.reader():
            return self._documents[key]

    def __setitem__(self, key, val):
        """Store the document ``val`` under ``key``."""
        with self._rwlock.writer():
            self._documents[key] = val

    def __delitem__(self, key):
        """Delete the document stored under ``key``.

        Raises:
            KeyError: if no document is stored under ``key``.
        """
        with self._rwlock.writer():
            del self._documents[key]

    def __len__(self):
        """Number of documents remaining after purging expired ones."""
        self._remove_expired_documents()
        with self._rwlock.reader():
            return len(self._documents)

    @property
    def documents(self):
        """Yield the stored documents in insertion order.

        This is a generator: the expiry pass only runs once iteration starts,
        and the reader lock is held until the generator is exhausted or closed.
        """
        self._remove_expired_documents()
        with self._rwlock.reader():
            yield from self._documents.values()

    def _remove_expired_documents(self):
        """Apply every registered TTL index, deleting the documents it has expired."""
        for index in self._ttl_indexes.values():
            self._expire_documents(index)

    def _expire_documents(self, index):
        """Delete the documents that TTL ``index`` considers expired.

        The index is skipped when ``expireAfterSeconds`` cannot be converted
        with ``int()``, or when its key does not consist of exactly one field.
        """
        # TODO(juannyg): use a caching mechanism to avoid re-expiring the documents if
        # we just did and no document was added / updated
        try:
            expiry = int(index['expireAfterSeconds'])
        except (TypeError, ValueError):
            # Ignore non-integer values: int('abc') raises ValueError while
            # int([5]) or int({}) raise TypeError.
            return
        # Only single-field keys are handled; ignore compound (and empty) keys.
        # An empty key would otherwise make next() below raise StopIteration.
        if len(index['key']) != 1:
            return
        # "key" structure = list of (field name, direction) tuples
        ttl_field_name = next(iter(index['key']))[0]
        ttl_now = mongomock.utcnow()

        with self._rwlock.reader():
            # Collect the storage keys themselves rather than doc['_id']: this
            # class does not require the key to equal the document's _id, and a
            # document need not contain an _id field at all.
            expired_keys = [
                key for key, doc in self._documents.items()
                if self._value_meets_expiry(doc.get(ttl_field_name), expiry, ttl_now)
            ]
        if not expired_keys:
            return
        with self._rwlock.writer():
            for key in expired_keys:
                # Another thread may have removed the document between the scan
                # and taking the writer lock; it is gone either way.
                self._documents.pop(key, None)

    def _value_meets_expiry(self, val, expiry, ttl_now):
        """Whether ``val`` is at least ``expiry`` seconds older than ``ttl_now``.

        Lists are reduced to their earliest comparable element and falsy values
        to ``datetime.max`` by ``_get_min_datetime_from_value``. Values that
        cannot be subtracted from ``ttl_now`` (``TypeError``) never expire.
        """
        val_to_compare = _get_min_datetime_from_value(val)
        try:
            return (ttl_now - val_to_compare).total_seconds() >= expiry
        except TypeError:
            return False
def _get_min_datetime_from_value(val):
    """Reduce a TTL field value to the datetime used for expiry checks.

    A non-empty list collapses to its earliest element that can be ordered
    against ``datetime.datetime.max``; elements that cannot be compared are
    skipped. A falsy value (``None``, ``[]``, ``0``, ``''``) yields
    ``datetime.datetime.max``. Any other value is returned unchanged.
    """
    if isinstance(val, list) and val:
        earliest = datetime.datetime.max
        for candidate in val:
            earliest = _min_dt(earliest, candidate)
        return earliest
    return val if val else datetime.datetime.max
def _min_dt(dt1, dt2):
    """Return the earlier of ``dt1`` and ``dt2``; ``dt2`` wins ties.

    If the two values cannot be ordered (``<`` raises ``TypeError``), ``dt1``
    is kept. When used as a ``reduce`` step, this skips incomparable elements.
    """
    try:
        keep_first = bool(dt1 < dt2)
    except TypeError:
        # Incomparable values (e.g. a datetime vs. a string): keep the accumulator.
        return dt1
    return dt1 if keep_first else dt2
|