File: test_ocsp_cache.py

package info (click to toggle)
pymongo 4.15.5-1
  • links: PTS, VCS
  • area: main
  • in suites: forky
  • size: 23,376 kB
  • sloc: python: 107,945; ansic: 4,601; javascript: 137; makefile: 38; sh: 10
file content (138 lines) | stat: -rw-r--r-- 5,145 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test the pymongo ocsp_support module."""
from __future__ import annotations

import random
import sys
from collections import namedtuple
from datetime import datetime, timedelta, timezone
from os import urandom
from time import sleep
from typing import Any

sys.path[0:0] = [""]

from test import unittest

from pymongo.ocsp_cache import _OCSPCache


class TestOcspCache(unittest.TestCase):
    """Unit tests for ``_OCSPCache``.

    Real OCSP request/response objects are replaced by small namedtuple
    stand-ins built in ``setUpClass``; timestamps are naive UTC datetimes
    produced by ``_utcnow``.
    """

    MockHashAlgorithm: Any
    MockOcspRequest: Any
    MockOcspResponse: Any

    @classmethod
    def setUpClass(cls):
        # Minimal stand-ins exposing only the fields these tests populate.
        cls.MockHashAlgorithm = namedtuple("MockHashAlgorithm", ["name"])  # type: ignore
        cls.MockOcspRequest = namedtuple(  # type: ignore
            "MockOcspRequest",
            ["hash_algorithm", "issuer_name_hash", "issuer_key_hash", "serial_number"],
        )
        cls.MockOcspResponse = namedtuple(  # type: ignore
            "MockOcspResponse", ["this_update", "next_update"]
        )

    def setUp(self):
        # Every test works against its own fresh cache instance.
        self.cache = _OCSPCache()

    @staticmethod
    def _utcnow():
        """Return the current UTC time as a naive ``datetime``.

        All mock timestamps are derived from this value. NOTE(review): this
        presumably mirrors the naive-UTC clock ``_OCSPCache`` compares
        against -- confirm in ``pymongo.ocsp_cache`` if that ever changes.
        """
        return datetime.now(tz=timezone.utc).replace(tzinfo=None)

    def _create_mock_request(self):
        """Return a mock OCSP request with randomised identifying fields.

        Hash algorithm, issuer name/key hashes and serial number are all
        random, so separate calls yield distinct requests with overwhelming
        probability.
        """
        hash_algorithm = self.MockHashAlgorithm(random.choice(["sha1", "md5", "sha256"]))
        issuer_name_hash = urandom(8)
        issuer_key_hash = urandom(8)
        serial_number = random.randint(0, 10**10)
        return self.MockOcspRequest(
            hash_algorithm=hash_algorithm,
            issuer_name_hash=issuer_name_hash,
            issuer_key_hash=issuer_key_hash,
            serial_number=serial_number,
        )

    def _create_mock_response(self, this_update_delta_seconds, next_update_delta_seconds):
        """Return a mock OCSP response with times relative to "now".

        :param this_update_delta_seconds: offset in seconds (may be negative
            or fractional) from now for ``this_update``.
        :param next_update_delta_seconds: offset in seconds from now for
            ``next_update``, or ``None`` to leave ``next_update`` unset.
        :returns: a ``MockOcspResponse`` whose timestamps are naive UTC.
        """
        now = self._utcnow()
        this_update = now + timedelta(seconds=this_update_delta_seconds)
        if next_update_delta_seconds is not None:
            next_update = now + timedelta(seconds=next_update_delta_seconds)
        else:
            next_update = None
        return self.MockOcspResponse(this_update=this_update, next_update=next_update)

    def _add_mock_cache_entry(self, mock_request, mock_response):
        """Seed the cache by writing straight into its internal ``_data``.

        This deliberately bypasses ``_OCSPCache.__setitem__`` so a test can
        start from a known cache state regardless of that method's rules.
        """
        key = self.cache._get_cache_key(mock_request)
        self.cache._data[key] = mock_response

    def test_simple(self):
        # Start with 1 valid entry in the cache.
        request = self._create_mock_request()
        response = self._create_mock_response(-10, +3600)
        self._add_mock_cache_entry(request, response)

        # Ensure entry can be retrieved.
        self.assertEqual(self.cache[request], response)

        # Valid entries with an earlier next_update have no effect.
        response_1 = self._create_mock_response(-20, +1800)
        self.cache[request] = response_1
        self.assertEqual(self.cache[request], response)

        # Invalid entries with a later this_update have no effect.
        response_2 = self._create_mock_response(+20, +1800)
        self.cache[request] = response_2
        self.assertEqual(self.cache[request], response)

        # Invalid entries with passed next_update have no effect.
        response_3 = self._create_mock_response(-10, -5)
        self.cache[request] = response_3
        self.assertEqual(self.cache[request], response)

        # Valid entries with a later next_update update the cache.
        response_new = self._create_mock_response(-5, +7200)
        self.cache[request] = response_new
        self.assertEqual(self.cache[request], response_new)

        # Entries with an unset next_update purge the cache.
        response_notset = self._create_mock_response(-5, None)
        self.cache[request] = response_notset
        with self.assertRaises(KeyError):
            _ = self.cache[request]

    def test_invalidate(self):
        # Start with 1 valid entry in the cache. The validity window is short
        # to keep the test fast, but wide enough that the first lookup below
        # does not race the expiry on a loaded machine.
        request = self._create_mock_request()
        response = self._create_mock_response(-10, +0.5)
        self._add_mock_cache_entry(request, response)

        # Ensure entry can be retrieved.
        self.assertEqual(self.cache[request], response)

        # Sleep until just past the entry's own next_update (instead of a
        # fixed delay unrelated to it), then ensure KeyError is raised.
        remaining = (response.next_update - self._utcnow()).total_seconds()
        sleep(max(remaining, 0.0) + 0.1)
        with self.assertRaises(KeyError):
            _ = self.cache[request]

    def test_non_existent(self):
        # Start with 1 valid entry in the cache.
        request = self._create_mock_request()
        response = self._create_mock_response(-10, +10)
        self._add_mock_cache_entry(request, response)

        # Attempt to retrieve non-existent entry must raise KeyError.
        with self.assertRaises(KeyError):
            _ = self.cache[self._create_mock_request()]


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()