1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
|
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the pymongo ocsp_support module."""
from __future__ import annotations
import random
import sys
from collections import namedtuple
from datetime import datetime, timedelta, timezone
from os import urandom
from time import sleep
from typing import Any
sys.path[0:0] = [""]
from test import unittest
from pymongo.ocsp_cache import _OCSPCache
class TestOcspCache(unittest.TestCase):
MockHashAlgorithm: Any
MockOcspRequest: Any
MockOcspResponse: Any
@classmethod
def setUpClass(cls):
cls.MockHashAlgorithm = namedtuple("MockHashAlgorithm", ["name"]) # type: ignore
cls.MockOcspRequest = namedtuple( # type: ignore
"MockOcspRequest",
["hash_algorithm", "issuer_name_hash", "issuer_key_hash", "serial_number"],
)
cls.MockOcspResponse = namedtuple( # type: ignore
"MockOcspResponse", ["this_update", "next_update"]
)
def setUp(self):
self.cache = _OCSPCache()
def _create_mock_request(self):
hash_algorithm = self.MockHashAlgorithm(random.choice(["sha1", "md5", "sha256"]))
issuer_name_hash = urandom(8)
issuer_key_hash = urandom(8)
serial_number = random.randint(0, 10**10)
return self.MockOcspRequest(
hash_algorithm=hash_algorithm,
issuer_name_hash=issuer_name_hash,
issuer_key_hash=issuer_key_hash,
serial_number=serial_number,
)
def _create_mock_response(self, this_update_delta_seconds, next_update_delta_seconds):
now = datetime.now(tz=timezone.utc).replace(tzinfo=None)
this_update = now + timedelta(seconds=this_update_delta_seconds)
if next_update_delta_seconds is not None:
next_update = now + timedelta(seconds=next_update_delta_seconds)
else:
next_update = None
return self.MockOcspResponse(this_update=this_update, next_update=next_update)
def _add_mock_cache_entry(self, mock_request, mock_response):
key = self.cache._get_cache_key(mock_request)
self.cache._data[key] = mock_response
def test_simple(self):
# Start with 1 valid entry in the cache.
request = self._create_mock_request()
response = self._create_mock_response(-10, +3600)
self._add_mock_cache_entry(request, response)
# Ensure entry can be retrieved.
self.assertEqual(self.cache[request], response)
# Valid entries with an earlier next_update have no effect.
response_1 = self._create_mock_response(-20, +1800)
self.cache[request] = response_1
self.assertEqual(self.cache[request], response)
# Invalid entries with a later this_update have no effect.
response_2 = self._create_mock_response(+20, +1800)
self.cache[request] = response_2
self.assertEqual(self.cache[request], response)
# Invalid entries with passed next_update have no effect.
response_3 = self._create_mock_response(-10, -5)
self.cache[request] = response_3
self.assertEqual(self.cache[request], response)
# Valid entries with a later next_update update the cache.
response_new = self._create_mock_response(-5, +7200)
self.cache[request] = response_new
self.assertEqual(self.cache[request], response_new)
# Entries with an unset next_update purge the cache.
response_notset = self._create_mock_response(-5, None)
self.cache[request] = response_notset
with self.assertRaises(KeyError):
_ = self.cache[request]
def test_invalidate(self):
# Start with 1 valid entry in the cache.
request = self._create_mock_request()
response = self._create_mock_response(-10, +0.25)
self._add_mock_cache_entry(request, response)
# Ensure entry can be retrieved.
self.assertEqual(self.cache[request], response)
# Wait for entry to become invalid and ensure KeyError is raised.
sleep(0.5)
with self.assertRaises(KeyError):
_ = self.cache[request]
def test_non_existent(self):
# Start with 1 valid entry in the cache.
request = self._create_mock_request()
response = self._create_mock_response(-10, +10)
self._add_mock_cache_entry(request, response)
# Attempt to retrieve non-existent entry must raise KeyError.
with self.assertRaises(KeyError):
_ = self.cache[self._create_mock_request()]
if __name__ == "__main__":
unittest.main()
|