File: test_streaming_failover.py

package info (click to toggle)
python-azure 20250603%2Bgit-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (193 lines) | stat: -rw-r--r-- 10,698 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.

import unittest
import uuid

import pytest

import azure.cosmos._cosmos_client_connection as cosmos_client_connection
import azure.cosmos._global_endpoint_manager as global_endpoint_manager
import azure.cosmos.documents as documents
import azure.cosmos.exceptions as exceptions
import azure.cosmos.http_constants as http_constants
import test_config
from azure.cosmos import _endpoint_discovery_retry_policy
from azure.cosmos import _retry_utility
from azure.cosmos import cosmos_client, PartitionKey
from azure.cosmos._request_object import RequestObject
from azure.cosmos.http_constants import HttpHeaders, StatusCodes, SubStatusCodes


@pytest.mark.cosmosEmulator
@pytest.mark.skip
class TestStreamingFailOver(unittest.TestCase):
    """Failover tests for the endpoint-discovery retry path, driven by in-test mocks.

    ``_retry_utility.ExecuteFunction`` and the client's account / location lookups are
    monkeypatched with the mock methods defined below. Every patch is undone through
    ``addCleanup`` so that a failing assertion cannot leak mocks into other tests.
    """
    DEFAULT_ENDPOINT = "https://geotest.documents.azure.com:443/"
    MASTER_KEY = "SomeKeyValue"
    WRITE_ENDPOINT1 = "https://geotest-WestUS.documents.azure.com:443/"
    WRITE_ENDPOINT2 = "https://geotest-CentralUS.documents.azure.com:443/"
    READ_ENDPOINT1 = "https://geotest-SouthCentralUS.documents.azure.com:443/"
    READ_ENDPOINT2 = "https://geotest-EastUS.documents.azure.com:443/"
    WRITE_ENDPOINT_NAME1 = "West US"
    WRITE_ENDPOINT_NAME2 = "Central US"
    READ_ENDPOINT_NAME1 = "South Central US"
    READ_ENDPOINT_NAME2 = "East US"
    preferred_regional_endpoints = [READ_ENDPOINT_NAME1, READ_ENDPOINT_NAME2]
    # Class-level defaults kept for compatibility; setUp shadows both with per-test state.
    counter = 0
    endpoint_sequence = []

    def setUp(self):
        """Reset the mock call counter and recorded endpoints for each test."""
        super().setUp()
        # Instance attributes: appending to the class-level list would carry recorded
        # endpoints over from one test into the next.
        self.counter = 0
        self.endpoint_sequence = []

    def test_streaming_fail_over(self):
        """Writes rejected with WRITE_FORBIDDEN should alternate between the write endpoints."""
        self.OriginalExecuteFunction = _retry_utility.ExecuteFunction
        _retry_utility.ExecuteFunction = self._MockExecuteFunctionEndpointDiscover
        # Module-global patch: always undo it, even when an assertion below fails.
        self.addCleanup(setattr, _retry_utility, "ExecuteFunction", self.OriginalExecuteFunction)
        connection_policy = documents.ConnectionPolicy()
        connection_policy.PreferredLocations = self.preferred_regional_endpoints
        connection_policy.DisableSSLVerification = True

        client = cosmos_client.CosmosClient(self.DEFAULT_ENDPOINT, self.MASTER_KEY,
                                            consistency_level=documents.ConsistencyLevel.Eventual,
                                            connection_policy=connection_policy)
        connection = client.client_connection
        location_cache = connection._global_endpoint_manager.location_cache

        # Save the bound methods themselves (not their return values) so the
        # restore puts callables back rather than lists.
        self.original_get_database_account = connection.GetDatabaseAccount
        self.original_get_read_endpoints = location_cache.get_read_regional_routing_contexts
        self.original_get_write_endpoints = location_cache.get_write_regional_routing_contexts
        connection.GetDatabaseAccount = self.mock_get_database_account
        location_cache.get_read_regional_routing_contexts = self.mock_get_read_endpoints
        location_cache.get_write_regional_routing_contexts = self.mock_get_write_endpoints
        # The patches live on these instances, so restore them on the same instances
        # (not on the CosmosClientConnection class).
        self.addCleanup(setattr, connection, "GetDatabaseAccount", self.original_get_database_account)
        self.addCleanup(setattr, location_cache, "get_read_regional_routing_contexts",
                        self.original_get_read_endpoints)
        self.addCleanup(setattr, location_cache, "get_write_regional_routing_contexts",
                        self.original_get_write_endpoints)

        created_db = client.create_database_if_not_exists("streaming-db" + str(uuid.uuid4()))
        created_container = created_db.create_container("streaming-container" + str(uuid.uuid4()),
                                                        PartitionKey(path="/id"))

        document_definition = {'id': 'doc',
                               'name': 'sample document',
                               'key': 'value'}

        created_document = created_container.create_item(document_definition)

        self.assertDictEqual(created_document, {})
        self.assertDictEqual(created_document.get_response_headers(), {})

        self.assertEqual(self.counter, 10)
        # In the mock, reads succeed immediately. Every other call before the 10th
        # raises a WRITE_FORBIDDEN 403, so the endpoint discovery retry policy should
        # flip the resolved endpoint between the two write endpoints. From the 10th
        # call onward, the mock returns an empty success response.
        expected_sequence = [self.WRITE_ENDPOINT1, self.WRITE_ENDPOINT2] * 3
        self.assertEqual(self.endpoint_sequence[:len(expected_sequence)], expected_sequence)

    def mock_get_database_account(self, url_connection=None):
        """Return a multi-write DatabaseAccount built from the class's fake endpoints."""
        database_account = documents.DatabaseAccount()
        database_account._EnableMultipleWritableLocations = True
        database_account._WritableLocations = [
            {'name': self.WRITE_ENDPOINT_NAME1, 'databaseAccountEndpoint': self.WRITE_ENDPOINT1},
            {'name': self.WRITE_ENDPOINT_NAME2, 'databaseAccountEndpoint': self.WRITE_ENDPOINT2}
        ]
        database_account._ReadableLocations = [
            {'name': self.READ_ENDPOINT_NAME1, 'databaseAccountEndpoint': self.READ_ENDPOINT1},
            {'name': self.READ_ENDPOINT_NAME2, 'databaseAccountEndpoint': self.READ_ENDPOINT2}
        ]
        return database_account

    def mock_get_read_endpoints(self):
        """Return the two fake read locations as name/endpoint dicts."""
        return [
            {'name': self.READ_ENDPOINT_NAME1, 'databaseAccountEndpoint': self.READ_ENDPOINT1},
            {'name': self.READ_ENDPOINT_NAME2, 'databaseAccountEndpoint': self.READ_ENDPOINT2}
        ]

    def mock_get_write_endpoints(self):
        """Return the two fake write locations as name/endpoint dicts."""
        return [
            {'name': self.WRITE_ENDPOINT_NAME1, 'databaseAccountEndpoint': self.WRITE_ENDPOINT1},
            {'name': self.WRITE_ENDPOINT_NAME2, 'databaseAccountEndpoint': self.WRITE_ENDPOINT2}
        ]

    def _MockExecuteFunctionEndpointDiscover(self, function, *args, **kwargs):
        """Stand-in for ``_retry_utility.ExecuteFunction`` that forces write failover.

        ``args[1]`` is treated as the request object. Reads, and every call from the
        10th onward, succeed with an empty ``({}, {})`` result. Any other call records
        the endpoint it was routed to and raises a 403 with the WRITE_FORBIDDEN sub-status.
        """
        self.counter += 1
        # args[1] is what gets read, so at least two positional args are required.
        request = args[1] if len(args) > 1 else None
        if self.counter >= 10 or (request is not None
                                  and request.operation_type == documents._OperationType.Read):
            return {}, {}
        self.endpoint_sequence.append(getattr(request, "location_endpoint_to_route", None))
        response = test_config.FakeResponse({HttpHeaders.SubStatus: SubStatusCodes.WRITE_FORBIDDEN})
        raise exceptions.CosmosHttpResponseError(
            status_code=StatusCodes.FORBIDDEN,
            message="Request is not permitted in this region",
            response=response)

    def test_retry_policy_does_not_mark_null_locations_unavailable(self):
        """A ``None`` resolved location must not be marked unavailable for read or write."""
        self.OriginalExecuteFunction = _retry_utility.ExecuteFunction
        _retry_utility.ExecuteFunction = self._MockExecuteFunctionEndpointDiscover
        # Module-global patch: always undo it, even when an assertion below fails.
        self.addCleanup(setattr, _retry_utility, "ExecuteFunction", self.OriginalExecuteFunction)
        connection_policy = documents.ConnectionPolicy()
        connection_policy.PreferredLocations = self.preferred_regional_endpoints
        connection_policy.DisableSSLVerification = True

        client = cosmos_client.CosmosClient(self.DEFAULT_ENDPOINT, self.MASTER_KEY,
                                            consistency_level=documents.ConsistencyLevel.Eventual,
                                            connection_policy=connection_policy)
        connection = client.client_connection
        self.original_get_database_account = connection.GetDatabaseAccount
        connection.GetDatabaseAccount = self.mock_get_database_account
        # Restore on the patched instance, not on the CosmosClientConnection class.
        self.addCleanup(setattr, connection, "GetDatabaseAccount", self.original_get_database_account)

        endpoint_manager = global_endpoint_manager._GlobalEndpointManager(connection)

        self.original_mark_endpoint_unavailable_for_read_function = endpoint_manager.mark_endpoint_unavailable_for_read
        endpoint_manager.mark_endpoint_unavailable_for_read = self._mock_mark_endpoint_unavailable_for_read
        self.original_mark_endpoint_unavailable_for_write_function = endpoint_manager.mark_endpoint_unavailable_for_write
        endpoint_manager.mark_endpoint_unavailable_for_write = self._mock_mark_endpoint_unavailable_for_write
        self.original_resolve_service_endpoint = endpoint_manager.resolve_service_endpoint_for_partition
        endpoint_manager.resolve_service_endpoint_for_partition = self._mock_resolve_service_endpoint
        self.addCleanup(setattr, endpoint_manager, "mark_endpoint_unavailable_for_read",
                        self.original_mark_endpoint_unavailable_for_read_function)
        self.addCleanup(setattr, endpoint_manager, "mark_endpoint_unavailable_for_write",
                        self.original_mark_endpoint_unavailable_for_write_function)
        self.addCleanup(setattr, endpoint_manager, "resolve_service_endpoint_for_partition",
                        self.original_resolve_service_endpoint)

        # The read and write counters record how often the endpoint manager's
        # mark_endpoint_unavailable_for_read() and mark_endpoint_unavailable_for_write()
        # are called. When resolve_service_endpoint_for_partition() returns a None
        # location, neither function should be called.
        for operation_type in (documents._OperationType.Read, documents._OperationType.Create):
            with self.subTest(operation_type=operation_type):
                self._read_counter = 0
                self._write_counter = 0
                request = RequestObject(http_constants.ResourceType.Document, operation_type)
                endpoint_discovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy(
                    documents.ConnectionPolicy(), endpoint_manager, request)
                endpoint_discovery_retry_policy.ShouldRetry(exceptions.CosmosHttpResponseError(
                    status_code=http_constants.StatusCodes.FORBIDDEN))
                self.assertEqual(self._read_counter, 0)
                self.assertEqual(self._write_counter, 0)

    def _mock_mark_endpoint_unavailable_for_read(self, endpoint):
        """Count the call, then delegate to the saved original read-marking function."""
        self._read_counter += 1
        self.original_mark_endpoint_unavailable_for_read_function(endpoint)

    def _mock_mark_endpoint_unavailable_for_write(self, endpoint):
        """Count the call, then delegate to the saved original write-marking function."""
        self._write_counter += 1
        self.original_mark_endpoint_unavailable_for_write_function(endpoint)

    @staticmethod
    def _mock_resolve_service_endpoint(request):
        """Simulate endpoint resolution that yields no location."""
        return None


if __name__ == "__main__":
    # Allow running this module directly through the stdlib unittest runner.
    unittest.main()