| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 
 | # The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.
import uuid
from typing import List
import pytest
from azure.core.exceptions import ServiceRequestError
import test_config
from azure.cosmos import DatabaseAccount, _location_cache, CosmosClient, _global_endpoint_manager, \
    _cosmos_client_connection
from azure.cosmos._location_cache import RegionalRoutingContext
from _fault_injection_transport import FaultInjectionTransport
from azure.cosmos.exceptions import CosmosHttpResponseError
# Key under which the `setup` fixture exposes the test container.
COLLECTION = "created_collection"
# Regions of the test account. REGION_1 is listed first on purpose: the mock database
# account treats the first entry as the only writable (hub) region.
REGION_1, REGION_2 = test_config.TestConfig.WRITE_LOCATION, test_config.TestConfig.READ_LOCATION
REGION_3 = "West US 2"
ACCOUNT_REGIONS = [REGION_1, REGION_2, REGION_3]
@pytest.fixture()
def setup():
    """Yield a dict exposing the single-partition test container under ``COLLECTION``.

    Raises if the account key or host still hold their placeholder values.
    """
    cfg = TestPreferredLocations
    has_placeholder = (cfg.master_key == '[YOUR_KEY_HERE]'
                       or cfg.host == '[YOUR_ENDPOINT_HERE]')
    if has_placeholder:
        raise Exception(
            "You must specify your Azure Cosmos account values for "
            "'masterKey' and 'host' at the top of this class to run the "
            "tests.")
    # Keep the client referenced for the lifetime of the fixture.
    cosmos_client = CosmosClient(cfg.host, cfg.master_key, consistency_level="Session")
    database = cosmos_client.get_database_client(cfg.TEST_DATABASE_ID)
    container = database.get_container_client(cfg.TEST_CONTAINER_SINGLE_PARTITION_ID)
    yield {COLLECTION: container}
def preferred_locations():
    """Return ``(preferred_locations, default_endpoint)`` parametrization cases.

    The first five cases create the client against the global account host;
    the last five use the REGION_2 locational endpoint instead.
    """
    global_endpoint = test_config.TestConfig.host
    regional_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(global_endpoint, REGION_2)
    with_global = [
        [],
        [REGION_1, REGION_2],
        [REGION_1],
        [REGION_2, REGION_3],
        [REGION_1, REGION_2, REGION_3],
    ]
    with_regional = [
        [],
        [REGION_2],
        [REGION_3, REGION_1],
        [REGION_1, REGION_3],
        [REGION_1, REGION_2, REGION_3],
    ]
    return ([(regions, global_endpoint) for regions in with_global]
            + [(regions, regional_endpoint) for regions in with_regional])
def construct_item():
    """Build a fresh test document with a unique id and a random partition-key value."""
    item_id = "test_item_no_preferred_locations" + str(uuid.uuid4())
    pk_value = str(uuid.uuid4())
    pk_field = test_config.TestConfig.TEST_CONTAINER_PARTITION_KEY
    return {"id": item_id, pk_field: pk_value}
def error():
    """Return the CosmosHttpResponseError instances injected as regional failures.

    Produces, in order: 503/0, 408/0 and 404/1002 (status code / substatus).
    """
    code_pairs = [(503, 0), (408, 0), (404, 1002)]
    return [
        CosmosHttpResponseError(
            status_code=code,
            message=f"Error with status code {code} and substatus {sub}",
            sub_status=sub,
        )
        for code, sub in code_pairs
    ]
@pytest.mark.unittest
@pytest.mark.usefixtures("setup")
class TestPreferredLocations:
    """Tests for the client's effective preferred regions and no-preferred-locations failover."""
    host = test_config.TestConfig.host
    master_key = test_config.TestConfig.masterKey
    TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
    TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID
    partition_key = test_config.TestConfig.TEST_CONTAINER_PARTITION_KEY

    def setup_method_with_custom_transport(self, custom_transport, error_lambda, default_endpoint=host, **kwargs):
        """Create a client whose document requests to REGION_1 (or ``default_endpoint``) are faulted.

        :param custom_transport: FaultInjectionTransport that receives the fault rule.
        :param error_lambda: callable producing the injected failure for each matching request.
        :param default_endpoint: endpoint the client is created with; requests targeting it are faulted too.
        :param kwargs: extra keyword arguments forwarded to ``CosmosClient``.
        :returns: dict with keys ``client``, ``db`` (database client) and ``col`` (container client).
        """
        uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1)

        def predicate(r):
            # Fault document operations aimed at the "down" region or at the endpoint the
            # client was created with; ReadFeed operations are never faulted.
            return (FaultInjectionTransport.predicate_is_document_operation(r) and
                    (FaultInjectionTransport.predicate_targets_region(r, uri_down) or
                     FaultInjectionTransport.predicate_targets_region(r, default_endpoint)) and
                    not FaultInjectionTransport.predicate_is_operation_type(r, "ReadFeed"))

        custom_transport.add_fault(predicate, error_lambda)
        client = CosmosClient(default_endpoint,
                              self.master_key,
                              multiple_write_locations=True,
                              transport=custom_transport, consistency_level="Session", **kwargs)
        db = client.get_database_client(self.TEST_DATABASE_ID)
        container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID)
        return {"client": client, "db": db, "col": container}

    @pytest.mark.cosmosEmulator
    @pytest.mark.parametrize("preferred_location, default_endpoint", preferred_locations())
    def test_effective_preferred_regions(self, setup, preferred_location, default_endpoint):
        """Read routing contexts follow the preferred locations, else depend on the endpoint type."""
        original_get_database_account_stub = _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub
        original_get_database_account_check = _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck
        try:
            # Patch inside the try so the real methods are restored even if patching
            # fails half-way; otherwise the mock would leak into other tests.
            _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = \
                self.MockGetDatabaseAccount(ACCOUNT_REGIONS)
            _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = \
                self.MockGetDatabaseAccount(ACCOUNT_REGIONS)
            client = CosmosClient(default_endpoint, self.master_key, preferred_locations=preferred_location)
            # this will set up the location cache
            client.client_connection._global_endpoint_manager.force_refresh_on_startup(None)
        finally:
            _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = \
                original_get_database_account_stub
            _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = \
                original_get_database_account_check

        if preferred_location:
            # explicit preferred locations are used as given
            expected_locations = preferred_location
        elif default_endpoint != self.host:
            # client created with a regional endpoint and no preferred locations: only the hub region
            expected_locations = ACCOUNT_REGIONS[:1]
        else:
            # client created with the global endpoint and no preferred locations: all account regions
            expected_locations = ACCOUNT_REGIONS
        expected_endpoints = [
            RegionalRoutingContext(_location_cache.LocationCache.GetLocationalEndpoint(self.host, location))
            for location in expected_locations
        ]
        read_endpoints = client.client_connection._global_endpoint_manager.location_cache.read_regional_routing_contexts
        assert read_endpoints == expected_endpoints

    @pytest.mark.cosmosMultiRegion
    @pytest.mark.parametrize("error", error())
    def test_read_no_preferred_locations_with_errors(self, setup, error):
        """Reads fail over from faulted REGION_1 to REGION_2 when no preferred locations are set."""
        container = setup[COLLECTION]
        item_to_read = construct_item()
        container.create_item(item_to_read)

        # set up fault injection so that the first account region fails
        custom_transport = FaultInjectionTransport()
        error_lambda = lambda r: FaultInjectionTransport.error_after_delay(
            0,
            error
        )
        expected = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2)
        fault_setup = self.setup_method_with_custom_transport(custom_transport=custom_transport, error_lambda=error_lambda)
        fault_container = fault_setup["col"]
        response = fault_container.read_item(item=item_to_read["id"], partition_key=item_to_read[self.partition_key])
        request = response.get_response_headers()["_request"]
        # Validate the response comes from another region, meaning the account locations were used
        assert request.url.startswith(expected)

        # excluding REGION_2 leaves nowhere to fail over to, so the read must fail
        with pytest.raises(CosmosHttpResponseError):
            fault_container.read_item(item=item_to_read["id"], partition_key=item_to_read[self.partition_key],
                                      excluded_locations=[REGION_2])

    @pytest.mark.cosmosMultiRegion
    def test_write_no_preferred_locations_with_errors(self, setup):
        """Writes fail over from a down REGION_1 to REGION_2 when no preferred locations are set."""
        # set up fault injection so that the first account region fails
        custom_transport = FaultInjectionTransport()
        expected = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2)
        error_lambda = lambda r: FaultInjectionTransport.error_region_down()
        fault_setup = self.setup_method_with_custom_transport(custom_transport=custom_transport, error_lambda=error_lambda)
        fault_container = fault_setup["col"]
        response = fault_container.create_item(body=construct_item())
        request = response.get_response_headers()["_request"]
        # Validate the response comes from another region, meaning the account locations were used
        assert request.url.startswith(expected)

        # excluding REGION_2 leaves nowhere to fail over to, so the write must fail
        with pytest.raises(ServiceRequestError):
            fault_container.create_item(body=construct_item(), excluded_locations=[REGION_2])

    class MockGetDatabaseAccount(object):
        """Callable replacement for the SDK's database-account lookups.

        Returns a DatabaseAccount whose readable locations are ``regions`` (in order),
        whose only writable location is ``regions[0]``, and with multiple write
        locations disabled.
        """

        def __init__(
                self,
                regions: List[str],
        ):
            self.regions = regions

        def __call__(self, endpoint, **kwargs):
            # ``endpoint`` is ignored: locational endpoints are always derived from the
            # test host. ``kwargs`` is accepted and ignored so the mock tolerates any
            # keyword arguments the patched SDK methods might be called with.
            def to_location(region):
                locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(
                    TestPreferredLocations.host, region)
                return {'databaseAccountEndpoint': locational_endpoint, 'name': region}

            read_locations = [to_location(loc) for loc in self.regions]
            write_locations = [to_location(self.regions[0])]
            db_acc = DatabaseAccount()
            db_acc.DatabasesLink = "/dbs/"
            db_acc.MediaLink = "/media/"
            db_acc._ReadableLocations = read_locations
            db_acc._WritableLocations = write_locations
            db_acc._EnableMultipleWritableLocations = False
            db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"}
            return db_acc
 |