File: test_srv_polling.py

package info (click to toggle)
pymongo 4.15.3-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 23,692 kB
  • sloc: python: 107,407; ansic: 4,601; javascript: 137; makefile: 30; sh: 10
file content (92 lines) | stat: -rw-r--r-- 3,096 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run the SRV support tests."""
from __future__ import annotations

import asyncio
import sys
import time
from test.asynchronous.utils import flaky
from test.utils_shared import FunctionCallRecorder
from typing import Any

sys.path[0:0] = [""]

from test.asynchronous import AsyncPyMongoTestCase, client_knobs, unittest
from test.asynchronous.utils import async_wait_until

import pymongo
from pymongo import common
from pymongo.asynchronous.srv_resolver import _have_dnspython
from pymongo.errors import ConfigurationError

# NOTE(review): presumably read by pymongo's async-to-sync code generation
# tooling to mark this module as the asynchronous variant — confirm.
_IS_SYNC = False

# Short interval shared by the tests in this module; not referenced in the
# code shown here (presumably seconds, e.g. used as a TTL / rescan interval).
WAIT_TIME = 0.1


class SrvPollingKnobs:
    """Temporarily patch SRV polling behavior in the asynchronous driver.

    While enabled, ``common.MIN_SRV_RESCAN_INTERVAL`` is optionally
    overridden and
    ``pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl``
    is replaced by a wrapper. The wrapper first calls the original method,
    then substitutes the node list and/or TTL as configured. :meth:`disable`
    restores both saved values. Also usable as a context manager::

        with SrvPollingKnobs(ttl_time=WAIT_TIME) as knobs:
            ...

    :param ttl_time: if not None, returned instead of the resolver's TTL.
    :param min_srv_rescan_interval: if not None, assigned to
        ``common.MIN_SRV_RESCAN_INTERVAL`` while enabled.
    :param nodelist_callback: if not None, a zero-argument callable whose
        result is returned instead of the resolver's node list.
    :param count_resolver_calls: if True, the installed wrapper is itself
        wrapped in ``FunctionCallRecorder`` (presumably so tests can inspect
        the resolver calls — see that helper).
    """

    def __init__(
        self,
        ttl_time=None,
        min_srv_rescan_interval=None,
        nodelist_callback=None,
        count_resolver_calls=False,
    ):
        self.ttl_time = ttl_time
        self.min_srv_rescan_interval = min_srv_rescan_interval
        self.nodelist_callback = nodelist_callback
        self.count_resolver_calls = count_resolver_calls

        # Values saved by enable() and restored by disable().
        # old_dns_resolver_response being None means "not currently enabled".
        self.old_min_srv_rescan_interval = None
        self.old_dns_resolver_response = None

    def enable(self):
        """Save the current settings and install the patched resolver method."""
        original = pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl
        self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
        self.old_dns_resolver_response = original

        if self.min_srv_rescan_interval is not None:
            common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval

        async def mock_get_hosts_and_min_ttl(resolver, *args):
            # Do the real lookup first, then apply the configured overrides.
            # The wrapper closes over `original` rather than reading it back
            # from self, so disable() clearing the saved state cannot break it.
            # Extra positional arguments are forwarded so the wrapper stays
            # call-compatible with the method it replaces.
            nodes, ttl = await original(resolver, *args)
            if self.nodelist_callback is not None:
                nodes = self.nodelist_callback()
            if self.ttl_time is not None:
                ttl = self.ttl_time
            return nodes, ttl

        patch_func: Any
        if self.count_resolver_calls:
            # NOTE(review): a FunctionCallRecorder instance is not a plain
            # function, so unless it implements __get__ it will not be bound
            # as a method and `resolver` may not be passed — confirm.
            patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl)
        else:
            patch_func = mock_get_hosts_and_min_ttl

        pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func  # type: ignore

    def __enter__(self):
        """Enable the knobs and return this object for ``with ... as``."""
        self.enable()
        return self

    def disable(self):
        """Restore the values saved by :meth:`enable`.

        Does nothing when the knobs are not currently enabled. Calling it
        without a prior ``enable()``, or calling it twice, therefore cannot
        overwrite the module state with ``None``.
        """
        if self.old_dns_resolver_response is None:
            return
        common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval  # type: ignore
        pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = (  # type: ignore
            self.old_dns_resolver_response
        )
        self.old_min_srv_rescan_interval = None
        self.old_dns_resolver_response = None

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the patched state; exceptions are not suppressed."""
        self.disable()