File: test_streaming_failover.py

package info (click to toggle)
python-azure 20201208%2Bgit-6
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 1,437,920 kB
  • sloc: python: 4,287,452; javascript: 269; makefile: 198; sh: 187; xml: 106
file content (143 lines) | stat: -rw-r--r-- 8,157 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import unittest
import azure.cosmos._cosmos_client_connection as cosmos_client_connection
import pytest
import azure.cosmos.documents as documents
import azure.cosmos.exceptions as exceptions
import test_config
from azure.cosmos.http_constants import HttpHeaders, StatusCodes, SubStatusCodes
from azure.cosmos import _retry_utility
from azure.cosmos import _endpoint_discovery_retry_policy
from azure.cosmos._request_object import RequestObject
import azure.cosmos._global_endpoint_manager as global_endpoint_manager
import azure.cosmos.http_constants as http_constants

# Module-level pytest marker: tags every test collected from this file with
# ``cosmosEmulator`` so the suite can be selected or deselected via ``-m``.
pytestmark = pytest.mark.cosmosEmulator

@pytest.mark.usefixtures("teardown")
class TestStreamingFailover(unittest.TestCase):
    """Tests for endpoint-discovery retries using mocked endpoints.

    No real service is contacted by the mocked calls: the tests replace
    ``_retry_utility.ExecuteFunction`` and
    ``CosmosClientConnection.GetDatabaseAccount`` with the mock methods
    defined on this class. Every replacement is installed through
    ``_patch_attribute`` so that it is undone even when a test fails.
    """

    DEFAULT_ENDPOINT = "https://geotest.documents.azure.com:443/"
    MASTER_KEY = "SomeKeyValue"
    WRITE_ENDPOINT1 = "https://geotest-WestUS.documents.azure.com:443/"
    WRITE_ENDPOINT2 = "https://geotest-CentralUS.documents.azure.com:443/"
    READ_ENDPOINT1 = "https://geotest-SouthCentralUS.documents.azure.com:443/"
    READ_ENDPOINT2 = "https://geotest-EastUS.documents.azure.com:443/"
    WRITE_ENDPOINT_NAME1 = "West US"
    WRITE_ENDPOINT_NAME2 = "Central US"
    READ_ENDPOINT_NAME1 = "South Central US"
    READ_ENDPOINT_NAME2 = "East US"
    preferred_regional_endpoints = [READ_ENDPOINT_NAME1, READ_ENDPOINT_NAME2]
    # Class-level defaults kept for backward compatibility; setUp() rebinds
    # both on the instance so the list is never shared between tests.
    counter = 0
    endpoint_sequence = []

    def setUp(self):
        """Give each test its own call counter and endpoint log."""
        super(TestStreamingFailover, self).setUp()
        self.counter = 0
        self.endpoint_sequence = []

    def _patch_attribute(self, owner, name, replacement):
        """Replace ``owner.<name>`` with *replacement* for the current test.

        The previous value is put back via ``addCleanup``, so the restore
        happens whether the test passes, fails or errors.

        :param owner: object (module, class or instance) holding the attribute.
        :param name: attribute name to replace.
        :param replacement: value to install for the duration of the test.
        :return: the value the attribute had before it was replaced.
        """
        original = getattr(owner, name)
        setattr(owner, name, replacement)
        self.addCleanup(setattr, owner, name, original)
        return original

    def test_streaming_failover(self):
        """A create that keeps hitting write-forbidden alternates write endpoints."""
        self.OriginalExecuteFunction = self._patch_attribute(
            _retry_utility, "ExecuteFunction", self._MockExecuteFunctionEndpointDiscover)
        connection_policy = documents.ConnectionPolicy()
        connection_policy.PreferredLocations = self.preferred_regional_endpoints
        connection_policy.DisableSSLVerification = True
        self.original_get_database_account = self._patch_attribute(
            cosmos_client_connection.CosmosClientConnection,
            "GetDatabaseAccount",
            self.mock_get_database_account)

        client = cosmos_client_connection.CosmosClientConnection(
            self.DEFAULT_ENDPOINT,
            {'masterKey': self.MASTER_KEY},
            connection_policy,
            documents.ConsistencyLevel.Eventual)

        document_definition = {
            'id': 'doc',
            'name': 'sample document',
            'key': 'value',
        }

        created_document = client.CreateItem("dbs/mydb/colls/mycoll", document_definition)

        # The mocked execute function returns empty result and headers.
        self.assertDictEqual(created_document, {})
        self.assertDictEqual(client.last_response_headers, {})

        self.assertEqual(self.counter, 10)
        # First request is an initial read collection.
        # Next 8 requests hit forbidden write exceptions and the endpoint retry policy keeps
        # flipping the resolved endpoint between the 2 write endpoints.
        # The 10th request returns the actual read document.
        expected_sequence = [self.WRITE_ENDPOINT1, self.WRITE_ENDPOINT2] * 4
        self.assertEqual(self.endpoint_sequence[:8], expected_sequence)

    def mock_get_database_account(self, url_connection=None, **kwargs):
        """Return a fake multi-write database account with two write and two read regions.

        :param url_connection: accepted for signature compatibility; unused.
        :param kwargs: extra keyword arguments from the caller; ignored.
        :return: a ``documents.DatabaseAccount`` populated from the class constants.
        """
        database_account = documents.DatabaseAccount()
        database_account._EnableMultipleWritableLocations = True
        database_account._WritableLocations = [
            {'name': self.WRITE_ENDPOINT_NAME1, 'databaseAccountEndpoint': self.WRITE_ENDPOINT1},
            {'name': self.WRITE_ENDPOINT_NAME2, 'databaseAccountEndpoint': self.WRITE_ENDPOINT2},
        ]
        database_account._ReadableLocations = [
            {'name': self.READ_ENDPOINT_NAME1, 'databaseAccountEndpoint': self.READ_ENDPOINT1},
            {'name': self.READ_ENDPOINT_NAME2, 'databaseAccountEndpoint': self.READ_ENDPOINT2},
        ]
        return database_account

    def _MockExecuteFunctionEndpointDiscover(self, function, *args, **kwargs):
        """Stand-in for ``_retry_utility.ExecuteFunction``.

        Counts every call. Read operations, and any call from the 10th
        onwards, succeed with an empty ``(result, headers)`` pair. Every other
        call records the endpoint the request was routed to and raises a 403
        with the write-forbidden sub-status.

        :param function: the function that would have been executed; never called.
        :param args: positional arguments for *function*; ``args[1]`` is read as
            the request object (``operation_type``, ``location_endpoint_to_route``).
        :raises exceptions.CosmosHttpResponseError: for non-read calls 1 through 9.
        """
        self.counter += 1
        # The request object lives at index 1, so at least two positional
        # arguments are needed before it can be inspected.
        if self.counter >= 10 or (
                len(args) > 1 and args[1].operation_type == documents._OperationType.Read):
            return ({}, {})

        self.endpoint_sequence.append(args[1].location_endpoint_to_route)
        response = test_config.FakeResponse({HttpHeaders.SubStatus: SubStatusCodes.WRITE_FORBIDDEN})
        raise exceptions.CosmosHttpResponseError(
            status_code=StatusCodes.FORBIDDEN,
            message="Request is not permitted in this region",
            response=response)

    def test_retry_policy_does_not_mark_null_locations_unavailable(self):
        """No endpoint is marked unavailable when the resolved endpoint is ``None``."""
        self.original_get_database_account = self._patch_attribute(
            cosmos_client_connection.CosmosClientConnection,
            "GetDatabaseAccount",
            self.mock_get_database_account)

        client = cosmos_client_connection.CosmosClientConnection(
            self.DEFAULT_ENDPOINT,
            {'masterKey': self.MASTER_KEY},
            None,
            documents.ConsistencyLevel.Eventual)
        endpoint_manager = global_endpoint_manager._GlobalEndpointManager(client)

        self.original_mark_endpoint_unavailable_for_read_function = self._patch_attribute(
            endpoint_manager,
            "mark_endpoint_unavailable_for_read",
            self._mock_mark_endpoint_unavailable_for_read)
        self.original_mark_endpoint_unavailable_for_write_function = self._patch_attribute(
            endpoint_manager,
            "mark_endpoint_unavailable_for_write",
            self._mock_mark_endpoint_unavailable_for_write)
        self.original_resolve_service_endpoint = self._patch_attribute(
            endpoint_manager,
            "resolve_service_endpoint",
            self._mock_resolve_service_endpoint)

        # Read and write counters count the number of times the endpoint manager's
        # mark_endpoint_unavailable_for_read() and mark_endpoint_unavailable_for_write()
        # functions were called. When a 'None' location is returned by
        # resolve_service_endpoint(), these functions should not be called.
        for operation_type in (documents._OperationType.Read, documents._OperationType.Create):
            self._read_counter = 0
            self._write_counter = 0
            request = RequestObject(http_constants.ResourceType.Document, operation_type)
            retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy(
                documents.ConnectionPolicy(), endpoint_manager, request)
            retry_policy.ShouldRetry(exceptions.CosmosHttpResponseError(
                status_code=http_constants.StatusCodes.FORBIDDEN))
            self.assertEqual(self._read_counter, 0)
            self.assertEqual(self._write_counter, 0)

    def _mock_mark_endpoint_unavailable_for_read(self, endpoint):
        """Count the call, then delegate to the saved original read-marking function."""
        self._read_counter += 1
        self.original_mark_endpoint_unavailable_for_read_function(endpoint)

    def _mock_mark_endpoint_unavailable_for_write(self, endpoint):
        """Count the call, then delegate to the saved original write-marking function."""
        self._write_counter += 1
        self.original_mark_endpoint_unavailable_for_write_function(endpoint)

    def _mock_resolve_service_endpoint(self, request):
        """Always resolve to ``None``, regardless of *request*."""
        return None