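"""Integration tests for the SQLite backend classes (SQLiteCache and SQLiteDict)"""
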
import os
import pickle
import sqlite3
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from datetime import timedelta
from os.path import join
from sys import version_info
from tempfile import NamedTemporaryFile, gettempdir
from threading import Thread
from unittest.mock import patch

import pytest
from platformdirs import user_cache_dir

from requests_cache.backends import BaseCache, SQLiteCache, SQLiteDict
from requests_cache.backends.sqlite import MEMORY_URI
from requests_cache.models import CachedResponse
from requests_cache.policy import utcnow
from tests.conftest import N_ITERATIONS, skip_pypy
from tests.integration.base_cache_test import BaseCacheTest
from tests.integration.base_storage_test import CACHE_NAME, BaseStorageTest


class TestSQLiteDict(BaseStorageTest):
    storage_class = SQLiteDict
    init_kwargs = {'use_temp': True}

    @classmethod
    def teardown_class(cls):
        try:
            os.unlink(f'{CACHE_NAME}.sqlite')
        except Exception:
            pass

    @patch('requests_cache.backends.sqlite.sqlite3')
    def test_connection_kwargs(self, mock_sqlite):
        """A spot check to make sure optional connection kwargs get passed to the connection"""
        cache = self.storage_class('test', use_temp=True, timeout=0.5, invalid_kwarg='???')
        # The invalid kwarg should be filtered out before connecting
        mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5, check_same_thread=False)
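
    # Note: the asserted check_same_thread=False lets sqlite3 share one connection across
    # threads; concurrent access is presumably serialized by the backend's own locking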

    def test_use_cache_dir(self):
        relative_path = self.storage_class(CACHE_NAME).db_path
        cache_dir_path = self.storage_class(CACHE_NAME, use_cache_dir=True).db_path
        assert not str(relative_path).startswith(user_cache_dir())
        assert str(cache_dir_path).startswith(user_cache_dir())

    def test_use_temp(self):
        relative_path = self.storage_class(CACHE_NAME).db_path
        temp_path = self.storage_class(CACHE_NAME, use_temp=True).db_path
        assert not str(relative_path).startswith(gettempdir())
        assert str(temp_path).startswith(gettempdir())

    def test_use_memory(self):
        cache = self.init_cache(use_memory=True)
        assert cache.db_path == MEMORY_URI
        for i in range(20):
            cache[f'key_{i}'] = f'value_{i}'
        for i in range(5):
            del cache[f'key_{i}']

        assert len(cache) == 15
        assert set(cache.keys()) == {f'key_{i}' for i in range(5, 20)}
        assert set(cache.values()) == {f'value_{i}' for i in range(5, 20)}

        cache.clear()
        assert len(cache) == 0

    def test_use_memory__uri(self):
        assert self.init_cache(':memory:').db_path == ':memory:'
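
    # Note: a bare ':memory:' database is private to a single connection, while use_memory=True
    # uses MEMORY_URI, presumably a shared-cache URI so all connections see the same data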

    def test_non_dir_parent_exists(self):
        """Expect a custom error message if a parent path already exists but isn't a directory"""
        with NamedTemporaryFile() as tmp:
            with pytest.raises(FileExistsError) as exc_info:
                self.storage_class(join(tmp.name, 'invalid_path'))
            assert 'not a directory' in str(exc_info.value)

    def test_bulk_commit(self):
        cache = self.init_cache()
        # An empty bulk_commit block should be a no-op
        with cache.bulk_commit():
            pass

        n_items = 1000
        with cache.bulk_commit():
            for i in range(n_items):
                cache[f'key_{i}'] = f'value_{i}'

        assert set(cache.keys()) == {f'key_{i}' for i in range(n_items)}
        assert set(cache.values()) == {f'value_{i}' for i in range(n_items)}

    def test_bulk_delete__chunked(self):
        """When deleting more items than SQLite can handle in a single statement, the deletion
        should be split into multiple smaller statements
        """
        # Populate the cache with more items than can fit in a single delete statement
        cache = self.init_cache()
        with cache.bulk_commit():
            for i in range(2000):
                cache[f'key_{i}'] = f'value_{i}'
        keys = list(cache.keys())

        # First pass to ensure that bulk_delete is split across three statements
        with patch.object(cache, 'connection') as mock_connection:
            con = mock_connection().__enter__.return_value
            cache.bulk_delete(keys)
            assert con.execute.call_count == 3

        # Second pass to actually delete keys and make sure it doesn't explode
        cache.bulk_delete(keys)
        assert len(cache) == 0
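
    # Note: the call count of 3 assumes deletes are chunked to stay under SQLite's host
    # parameter limit (SQLITE_MAX_VARIABLE_NUMBER, 999 by default in older versions of SQLite),
    # so 2000 keys require 3 statements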

    def test_bulk_commit__noop(self):
        def do_noop_bulk(cache):
            with cache.bulk_commit():
                pass
            del cache

        cache = self.init_cache()
        thread = Thread(target=do_noop_bulk, args=(cache,))
        thread.start()
        thread.join()

        # Make sure the connection is not closed by the thread
        cache['key_1'] = 'value_1'
        assert list(cache.keys()) == ['key_1']

    def test_switch_commit(self):
        cache = self.init_cache()
        cache['key_1'] = 'value_1'
        cache = self.init_cache(clear=False)
        assert 'key_1' in cache

        # With commit disabled, changes should not be persisted
        cache._can_commit = False
        cache['key_2'] = 'value_2'

        cache = self.init_cache(clear=False)
        assert 'key_2' not in cache
        assert cache._can_commit is True

    @skip_pypy
    @pytest.mark.parametrize('kwargs', [{'busy_timeout': 5}, {'fast_save': True}, {'wal': True}])
    def test_pragma(self, kwargs):
        """Test settings that make additional PRAGMA statements"""
        cache_1 = self.init_cache('cache_1', **kwargs)
        cache_2 = self.init_cache('cache_2', **kwargs)

        n = 500
        for i in range(n):
            cache_1[f'key_{i}'] = f'value_{i}'
            cache_2[f'key_{i * 2}'] = f'value_{i}'

        assert set(cache_1.keys()) == {f'key_{i}' for i in range(n)}
        assert set(cache_2.values()) == {f'value_{i}' for i in range(n)}

    def test_busy_timeout(self):
        cache = self.init_cache(busy_timeout=5)
        with cache.connection() as con:
            r = con.execute('PRAGMA busy_timeout').fetchone()
            assert r[0] == 5
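
    # Note: PRAGMA busy_timeout is specified in milliseconds; querying it echoes the value back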

    def test_wal_sync_mode(self):
        # Should default to 'NORMAL' (1)
        cache = self.init_cache(wal=True)
        with cache.connection() as con:
            r = con.execute('PRAGMA synchronous').fetchone()
            assert r[0] == 1

        # fast_save sets 'OFF' (0); not recommended, but should still work
        cache = self.init_cache(wal=True, fast_save=True)
        with cache.connection() as con:
            r = con.execute('PRAGMA synchronous').fetchone()
            assert r[0] == 0
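
    # For reference, PRAGMA synchronous values: 0 = OFF, 1 = NORMAL, 2 = FULL (SQLite's default)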

    def test_write_retry(self):
        cache = self.init_cache()
        locked_error = sqlite3.OperationalError('database is locked')

        # The write should be retried after the initial 'database is locked' error
        with patch.object(cache, '_write', side_effect=[locked_error, 1]) as mock_write:
            cache['key_1'] = 'value_1'
            assert mock_write.call_count == 2

    def test_write_retry__exceeded_retries(self):
        cache = self.init_cache()
        locked_error = sqlite3.OperationalError('database is locked')
        with patch.object(cache, '_write', side_effect=locked_error) as mock_write:
            cache['key_1'] = 'value_1'
            assert mock_write.call_count == 3
            assert 'key_1' not in cache

        # Set a custom number of retries
        cache = self.init_cache(retries=5)
        with patch.object(cache, '_write', side_effect=locked_error) as mock_write:
            cache['key_1'] = 'value_1'
            assert mock_write.call_count == 5

        # Set retries to 0 to disable retrying
        cache = self.init_cache(retries=0)
        with patch.object(cache, '_write', side_effect=locked_error) as mock_write:
            with pytest.raises(sqlite3.OperationalError):
                cache['key_1'] = 'value_1'
            assert mock_write.call_count == 1

        # Expect no change in behavior if retrying is disabled and there are no errors
        cache['key_1'] = 'value_1'
        assert 'key_1' in cache
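
    # Note: the first case above implies a default of 3 total write attempts before giving up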

    def test_write_retry__other_errors(self):
        """Errors other than 'OperationalError: database is locked' should not be retried"""
        cache = self.init_cache()
        error_1 = sqlite3.OperationalError('no more rows available')
        with patch.object(cache, '_write', side_effect=error_1):
            with pytest.raises(sqlite3.OperationalError):
                cache['key_1'] = 'value_1'

        error_2 = sqlite3.DatabaseError('hard drive is on fire')
        with patch.object(cache, '_write', side_effect=error_2):
            with pytest.raises(sqlite3.DatabaseError):
                cache['key_1'] = 'value_1'

    @skip_pypy
    @pytest.mark.parametrize('limit', [None, 50])
    def test_sorted__by_size(self, limit):
        cache = self.init_cache()

        # Insert items with decreasing size
        for i in range(100):
            suffix = 'padding' * (100 - i)
            cache[f'key_{i}'] = f'value_{i}_{suffix}'

        # Sorted items should be in ascending order by size
        items = list(cache.sorted(key='size', limit=limit))
        assert len(items) == (limit or 100)

        prev_item = None
        for item in items:
            assert prev_item is None or len(prev_item) < len(item)
            prev_item = item

    @skip_pypy
    def test_sorted__reversed(self):
        cache = self.init_cache()
        for i in range(100):
            cache[f'key_{i + 1:03}'] = f'value_{i + 1}'

        # With reversed=True, items should be in descending order by key
        items = list(cache.sorted(key='key', reversed=True))
        assert len(items) == 100
        for i, item in enumerate(items):
            assert item == f'value_{100 - i}'

    @skip_pypy
    def test_sorted__invalid_sort_key(self):
        cache = self.init_cache()
        cache['key_1'] = 'value_1'
        with pytest.raises(ValueError):
            list(cache.sorted(key='invalid_key'))

    @skip_pypy
    @pytest.mark.parametrize('limit', [None, 50])
    def test_sorted__by_expires(self, limit):
        cache = self.init_cache()
        now = utcnow()

        # Insert items with decreasing expiration time
        for i in range(100):
            response = CachedResponse(expires=now + timedelta(seconds=101 - i))
            cache[f'key_{i}'] = response

        # Sorted items should be in ascending order by expiration time
        items = list(cache.sorted(key='expires', limit=limit))
        assert len(items) == (limit or 100)

        prev_item = None
        for item in items:
            assert prev_item is None or prev_item.expires < item.expires
            prev_item = item

    @skip_pypy
    def test_sorted__exclude_expired(self):
        cache = self.init_cache()
        now = utcnow()

        # Make only odd-numbered items expired
        for i in range(100):
            delta = 101 - i
            if i % 2 == 1:
                delta -= 101
            response = CachedResponse(status_code=i, expires=now + timedelta(seconds=delta))
            cache[f'key_{i}'] = response

        # Items should only include unexpired (even-numbered) items, and still be in sorted order
        items = list(cache.sorted(key='expires', expired=False))
        assert len(items) == 50

        prev_item = None
        for item in items:
            assert prev_item is None or prev_item.expires < item.expires
            assert item.status_code % 2 == 0
            prev_item = item

    @skip_pypy
    def test_sorted__error(self):
        """sorted() should handle deserialization errors and not return invalid responses"""

        class BadSerializer:
            def loads(self, value):
                response = pickle.loads(value)
                # Simulate a deserialization error for a single item
                if response.cache_key == 'key_42':
                    raise pickle.PickleError()
                return response

            def dumps(self, value):
                return pickle.dumps(value)

        cache = self.init_cache(serializer=BadSerializer())
        for i in range(100):
            response = CachedResponse(status_code=i)
            response.cache_key = f'key_{i}'
            cache[f'key_{i}'] = response

        # All items except the one that failed to deserialize should be returned
        items = list(cache.sorted())
        assert len(items) == 99

    @pytest.mark.parametrize(
        'db_path, use_temp',
        [
            ('filesize_test', True),
            (':memory:', False),
        ],
    )
    def test_size(self, db_path, use_temp):
        """Test approximate expected size of a database, for both file-based and in-memory
        databases
        """
        cache = self.init_cache(db_path, use_temp=use_temp)
        for i in range(100):
            cache[f'key_{i}'] = f'value_{i}'
        assert 10000 < cache.size() < 200000
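
    # Note: for a file-based database, size() can use the file size on disk; for an in-memory
    # database it's presumably derived from PRAGMA page_count * page_size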


class TestSQLiteCache(BaseCacheTest):
    backend_class = SQLiteCache
    init_kwargs = {'use_temp': True}

    @classmethod
    def teardown_class(cls):
        try:
            os.unlink(f'{CACHE_NAME}.sqlite')
        except Exception:
            pass

    @patch.object(BaseCache, 'clear', side_effect=IOError)
    @patch('requests_cache.backends.sqlite.unlink', side_effect=os.unlink)
    def test_clear__failure(self, mock_unlink, mock_clear):
        """When a corrupted cache prevents a normal DROP TABLE, clear() should still succeed"""
        session = self.init_session(clear=False)
        session.cache.responses['key_1'] = 'value_1'
        session.cache.clear()
        assert len(session.cache.responses) == 0
        assert mock_unlink.call_count == 1

    @patch.object(BaseCache, 'clear', side_effect=IOError)
    def test_clear__file_already_deleted(self, mock_clear):
        """clear() should also succeed if the database file has already been deleted"""
        session = self.init_session(clear=False)
        session.cache.responses['key_1'] = 'value_1'
        os.unlink(session.cache.responses.db_path)
        session.cache.clear()
        assert len(session.cache.responses) == 0

    def test_db_path(self):
        """db_path is just provided as an alias, since responses and redirects share the same
        db file
        """
        session = self.init_session()
        assert session.cache.db_path == session.cache.responses.db_path

    def test_count(self):
        """count() should work the same as len(), but with the option to exclude expired responses"""
        session = self.init_session()
        now = utcnow()
        session.cache.responses['key_1'] = CachedResponse(expires=now + timedelta(1))
        session.cache.responses['key_2'] = CachedResponse(expires=now - timedelta(1))

        assert session.cache.count() == 2
        assert session.cache.count(expired=False) == 1

    def test_delete__single_key(self):
        """Vacuum should not be used after delete if there is only a single key"""
        session = self.init_session()
        session.cache.responses['key_1'] = 'value_1'
        with patch.object(SQLiteDict, 'vacuum') as mock_vacuum:
            session.cache.delete('key_1')
        mock_vacuum.assert_not_called()

    def test_delete__skip_vacuum(self):
        """Vacuum should not be used after delete if disabled"""
        session = self.init_session()
        session.cache.responses['key_1'] = 'value_1'
        session.cache.responses['key_2'] = 'value_2'
        with patch.object(SQLiteDict, 'vacuum') as mock_vacuum:
            session.cache.delete('key_1', 'key_2', vacuum=False)
        mock_vacuum.assert_not_called()
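
    # Note: VACUUM rebuilds the database file to reclaim space freed by deletions; it's
    # relatively expensive, which is presumably why it's skipped in the cases above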

    @patch.object(SQLiteDict, 'sorted')
    def test_filter__expired(self, mock_sorted):
        """Filtering by expiration status should use a more efficient SQL query"""
        session = self.init_session()
        session.cache.filter()
        mock_sorted.assert_called_with(expired=True)
        session.cache.filter(expired=False)
        mock_sorted.assert_called_with(expired=False)

    def test_sorted(self):
        """Test the wrapper method for SQLiteDict.sorted(), with all arguments combined"""
        session = self.init_session(clear=False)
        now = utcnow()

        # Insert items with decreasing expiration time; make the last 99 items expired
        for i in range(500):
            delta = 1000 - i
            if i > 400:
                delta -= 2000
            response = CachedResponse(status_code=i, expires=now + timedelta(seconds=delta))
            session.cache.responses[f'key_{i}'] = response

        # With reversed=True, items should be in descending order by expiration time
        items = list(session.cache.sorted(key='expires', expired=False, reversed=True, limit=100))
        assert len(items) == 100

        prev_item = None
        for item in items:
            assert prev_item is None or prev_item.expires > item.expires
            assert item.cache_key
            assert not item.is_expired
            prev_item = item

    # TODO: Remove after fixing issue with SQLite multiprocessing on python 3.12
    @pytest.mark.parametrize('executor_class', [ThreadPoolExecutor, ProcessPoolExecutor])
    @pytest.mark.parametrize('iteration', range(N_ITERATIONS))
    def test_concurrency(self, iteration, executor_class):
        if version_info >= (3, 12):
            pytest.xfail('Concurrent usage of SQLite backend is not yet supported on python 3.12')
        super().test_concurrency(iteration, executor_class)