File: streaming_failover_test.py

package info (click to toggle)
azure-cosmos-python 3.1.1-5
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,280 kB
  • sloc: python: 11,653; makefile: 155
file content (135 lines) | stat: -rw-r--r-- 7,753 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import unittest
import pytest
import azure.cosmos.cosmos_client as cosmos_client
import azure.cosmos.documents as documents
import azure.cosmos.errors as errors
from requests.exceptions import ConnectionError
from azure.cosmos.http_constants import HttpHeaders, StatusCodes, SubStatusCodes
import azure.cosmos.retry_utility as retry_utility
import azure.cosmos.endpoint_discovery_retry_policy as endpoint_discovery_retry_policy
from azure.cosmos.request_object import _RequestObject
import azure.cosmos.global_endpoint_manager as global_endpoint_manager
import azure.cosmos.http_constants as http_constants

@pytest.mark.usefixtures("teardown")
class TestStreamingFailover(unittest.TestCase):
    """Failover tests that run against mocked SDK internals.

    ``CosmosClient.GetDatabaseAccount`` is replaced with a mock that reports
    two writable and two readable regional endpoints, and (in the streaming
    failover test) ``retry_utility._ExecuteFunction`` is replaced with a mock
    that raises "write forbidden" failures.

    Every replacement is installed through ``_patch_attribute`` so that the
    original is restored by ``addCleanup`` even when an assertion fails.
    """

    DEFAULT_ENDPOINT = "https://geotest.documents.azure.com:443/"
    MASTER_KEY = "SomeKeyValue"
    WRITE_ENDPOINT1 = "https://geotest-WestUS.documents.azure.com:443/"
    WRITE_ENDPOINT2 = "https://geotest-CentralUS.documents.azure.com:443/"
    READ_ENDPOINT1 = "https://geotest-SouthCentralUS.documents.azure.com:443/"
    READ_ENDPOINT2 = "https://geotest-EastUS.documents.azure.com:443/"
    WRITE_ENDPOINT_NAME1 = "West US"
    WRITE_ENDPOINT_NAME2 = "Central US"
    READ_ENDPOINT_NAME1 = "South Central US"
    READ_ENDPOINT_NAME2 = "East US"
    preferred_regional_endpoints = [READ_ENDPOINT_NAME1, READ_ENDPOINT_NAME2]
    # Class-level defaults kept for backward compatibility; setUp() rebinds
    # both per test so no state is shared between test instances or re-runs.
    counter = 0
    endpoint_sequence = []

    def setUp(self):
        """Give every test its own call counter and endpoint log."""
        super(TestStreamingFailover, self).setUp()
        self.counter = 0
        self.endpoint_sequence = []

    def _patch_attribute(self, owner, name, replacement):
        """Replace ``owner.name`` with ``replacement`` for the current test.

        The original value is restored via ``addCleanup``, so the patch is
        undone whether the test passes, fails or errors.

        :param owner: module, class or object whose attribute is replaced.
        :param str name: name of the attribute to replace.
        :param replacement: value installed for the duration of the test.
        :return: the original attribute value.
        """
        original = getattr(owner, name)
        self.addCleanup(setattr, owner, name, original)
        setattr(owner, name, replacement)
        return original

    def test_streaming_failover(self):
        """A create that keeps hitting write-forbidden alternates write endpoints."""
        self.OriginalExecuteFunction = self._patch_attribute(
            retry_utility, '_ExecuteFunction', self._MockExecuteFunctionEndpointDiscover)
        connection_policy = documents.ConnectionPolicy()
        connection_policy.PreferredLocations = self.preferred_regional_endpoints
        connection_policy.DisableSSLVerification = True
        self.original_get_database_account = self._patch_attribute(
            cosmos_client.CosmosClient, 'GetDatabaseAccount', self.mock_get_database_account)

        client = cosmos_client.CosmosClient(
            self.DEFAULT_ENDPOINT,
            {'masterKey': self.MASTER_KEY},
            connection_policy,
            documents.ConsistencyLevel.Eventual)

        document_definition = {
            'id': 'doc',
            'name': 'sample document',
            'key': 'value',
        }

        created_document = client.CreateItem("dbs/mydb/colls/mycoll", document_definition)

        # The mock's successful result is a pair of empty dicts.
        self.assertDictEqual(created_document, {})
        self.assertDictEqual(client.last_response_headers, {})

        self.assertEqual(self.counter, 10)
        # First request is an initial read collection.
        # Next 8 requests hit forbidden write exceptions and the endpoint retry policy keeps
        # flipping the resolved endpoint between the 2 write endpoints.
        # The 10th request succeeds and returns the mocked (empty) result.
        expected_sequence = [self.WRITE_ENDPOINT1, self.WRITE_ENDPOINT2] * 4
        self.assertEqual(self.endpoint_sequence[:8], expected_sequence)

    def mock_get_database_account(self, url_connection = None):
        """Return a multi-write DatabaseAccount with two write and two read regions.

        :param url_connection: accepted for signature compatibility; unused.
        :return: a ``documents.DatabaseAccount`` populated from the class constants.
        """
        database_account = documents.DatabaseAccount()
        database_account._EnableMultipleWritableLocations = True
        database_account._WritableLocations = [
            {'name': self.WRITE_ENDPOINT_NAME1, 'databaseAccountEndpoint': self.WRITE_ENDPOINT1},
            {'name': self.WRITE_ENDPOINT_NAME2, 'databaseAccountEndpoint': self.WRITE_ENDPOINT2},
        ]
        database_account._ReadableLocations = [
            {'name': self.READ_ENDPOINT_NAME1, 'databaseAccountEndpoint': self.READ_ENDPOINT1},
            {'name': self.READ_ENDPOINT_NAME2, 'databaseAccountEndpoint': self.READ_ENDPOINT2},
        ]
        return database_account

    def _MockExecuteFunctionEndpointDiscover(self, function, *args, **kwargs):
        """Stand-in for ``retry_utility._ExecuteFunction``.

        Counts every call. Read requests, and any call from the 10th onwards,
        succeed with ``({}, {})``. Every other call records the endpoint the
        request is routed to and raises a 403 "write forbidden" failure.

        The request is taken from ``args[1]``; it is expected to expose
        ``operation_type`` and ``location_endpoint_to_route``.

        :raises errors.HTTPFailure: FORBIDDEN with the WRITE_FORBIDDEN sub-status.
        """
        self.counter += 1
        # Guard the index that is actually used (args[1], not args[0]).
        request = args[1] if len(args) > 1 else None
        is_read = (request is not None
                   and request.operation_type == documents._OperationType.Read)
        if self.counter >= 10 or is_read:
            return ({}, {})

        if request is None:
            # Without a request there is no endpoint to record; fail with a
            # clear message rather than an IndexError.
            self.fail("_ExecuteFunction mock was called without a request argument")

        self.endpoint_sequence.append(request.location_endpoint_to_route)
        raise errors.HTTPFailure(StatusCodes.FORBIDDEN, "Request is not permitted in this region", {HttpHeaders.SubStatus: SubStatusCodes.WRITE_FORBIDDEN})

    def test_retry_policy_does_not_mark_null_locations_unavailable(self):
        """ShouldRetry must not mark an endpoint unavailable when none was resolved."""
        self.original_get_database_account = self._patch_attribute(
            cosmos_client.CosmosClient, 'GetDatabaseAccount', self.mock_get_database_account)

        client = cosmos_client.CosmosClient(self.DEFAULT_ENDPOINT, {'masterKey': self.MASTER_KEY}, None, documents.ConsistencyLevel.Eventual)
        endpoint_manager = global_endpoint_manager._GlobalEndpointManager(client)

        self.original_mark_endpoint_unavailable_for_read_function = self._patch_attribute(
            endpoint_manager, 'mark_endpoint_unavailable_for_read',
            self._mock_mark_endpoint_unavailable_for_read)
        self.original_mark_endpoint_unavailable_for_write_function = self._patch_attribute(
            endpoint_manager, 'mark_endpoint_unavailable_for_write',
            self._mock_mark_endpoint_unavailable_for_write)
        self.original_resolve_service_endpoint = self._patch_attribute(
            endpoint_manager, 'resolve_service_endpoint',
            self._mock_resolve_service_endpoint)

        # Read and write counters count the number of times the endpoint manager's
        # mark_endpoint_unavailable_for_read() and mark_endpoint_unavailable_for_write()
        # functions were called. When a 'None' location is returned by
        # resolve_service_endpoint(), these functions should not be called.
        operation_types = (documents._OperationType.Read, documents._OperationType.Create)
        for operation_type in operation_types:
            self._read_counter = 0
            self._write_counter = 0
            request = _RequestObject(http_constants.ResourceType.Document, operation_type)
            endpointDiscovery_retry_policy = endpoint_discovery_retry_policy._EndpointDiscoveryRetryPolicy(
                documents.ConnectionPolicy(), endpoint_manager, request)
            endpointDiscovery_retry_policy.ShouldRetry(errors.HTTPFailure(http_constants.StatusCodes.FORBIDDEN))
            self.assertEqual(
                self._read_counter, 0,
                "mark_endpoint_unavailable_for_read was called for operation %s" % operation_type)
            self.assertEqual(
                self._write_counter, 0,
                "mark_endpoint_unavailable_for_write was called for operation %s" % operation_type)

    def _mock_mark_endpoint_unavailable_for_read(self, endpoint):
        """Count the call, then delegate to the saved original function."""
        self._read_counter += 1
        self.original_mark_endpoint_unavailable_for_read_function(endpoint)

    def _mock_mark_endpoint_unavailable_for_write(self, endpoint):
        """Count the call, then delegate to the saved original function."""
        self._write_counter += 1
        self.original_mark_endpoint_unavailable_for_write_function(endpoint)

    def _mock_resolve_service_endpoint(self, request):
        """Always resolve to no endpoint (``None``), whatever the request."""
        return None