1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
|
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from datetime import datetime
import time
from unittest import mock
from azure.core.credentials import AccessToken
from azure.core.exceptions import ResourceExistsError
from azure.identity.aio import DefaultAzureCredential
from azure.keyvault.keys.aio import KeyClient
from azure.keyvault.administration.aio import KeyVaultBackupClient
from devtools_testutils import ResourceGroupPreparer, StorageAccountPreparer
import pytest
from _shared.helpers_async import get_completed_future
from _shared.test_case_async import KeyVaultTestCase
from blob_container_preparer import BlobContainerPreparer
from test_backup_client import assert_in_progress_operation
from test_backup_client import assert_successful_operation
@pytest.mark.usefixtures("managed_hsm")
class BackupClientTests(KeyVaultTestCase):
    """Backup and restore tests for the async KeyVaultBackupClient.

    The class relies on the ``managed_hsm`` fixture, which is read here as a
    mapping with ``"url"`` and ``"playback_url"`` entries.
    """

    def __init__(self, *args, **kwargs):
        # match_body=False is forwarded to the base test case
        # (presumably it disables request-body matching in recordings -- confirm)
        super().__init__(*args, match_body=False, **kwargs)

    def setUp(self, *args, **kwargs):
        """Register the HSM URL with the scrubber on live runs, then run the base setUp."""
        if self.is_live:
            hsm = self.managed_hsm
            live_url = hsm["url"].lower()
            self.scrubber.register_name_pair(live_url, hsm["playback_url"])
        super().setUp(*args, **kwargs)

    @property
    def credential(self):
        """Return a DefaultAzureCredential on live runs, otherwise a mock credential.

        The mock exposes an async ``get_token`` that returns a fixed
        ``AccessToken`` whose expiry is one hour after the call.
        """
        if not self.is_live:

            async def fake_get_token(*_, **__):
                return AccessToken("secret", time.time() + 3600)

            return mock.Mock(get_token=fake_get_token)

        return DefaultAzureCredential()

    async def _backup_and_verify(self, backup_client, container_uri, sas_token):
        """Start a full backup, check its status before and after completion, and return the result."""
        poller = await backup_client.begin_full_backup(container_uri, sas_token)

        # the status is fetched by job id while the poller has not been awaited yet
        operation_id = poller.polling_method().resource().id
        assert_in_progress_operation(await backup_client.get_backup_status(operation_id))

        completed = await poller.result()
        assert_successful_operation(completed)

        assert_successful_operation(await backup_client.get_backup_status(operation_id))
        return completed

    async def _verify_restore(self, backup_client, poller):
        """Check a restore operation's status before and after awaiting the poller's result."""
        operation_id = poller.polling_method().resource().id
        assert_in_progress_operation(await backup_client.get_restore_status(operation_id))

        completed = await poller.result()
        assert_successful_operation(completed)

        assert_successful_operation(await backup_client.get_restore_status(operation_id))

    @ResourceGroupPreparer(random_name_enabled=True, use_cache=True)
    @StorageAccountPreparer(random_name_enabled=True)
    @BlobContainerPreparer()
    async def test_full_backup_and_restore(self, container_uri, sas_token):
        """Back up the whole HSM to the blob container, then restore from that backup."""
        backup_client = KeyVaultBackupClient(self.managed_hsm["url"], self.credential)

        # backup the vault and check its status and result
        backup_operation = await self._backup_and_verify(backup_client, container_uri, sas_token)

        # restore the backup; the folder name is the last path segment of the backup's blob URI
        folder_name = backup_operation.azure_storage_blob_container_uri.rsplit("/", 1)[-1]
        restore_poller = await backup_client.begin_full_restore(container_uri, sas_token, folder_name)

        # check restore status and result
        await self._verify_restore(backup_client, restore_poller)

    @ResourceGroupPreparer(random_name_enabled=True, use_cache=True)
    @StorageAccountPreparer(random_name_enabled=True)
    @BlobContainerPreparer()
    async def test_selective_key_restore(self, container_uri, sas_token):
        """Create a key, back up the HSM, restore only that key, then delete and purge it."""
        hsm_url = self.managed_hsm["url"]

        # create a key to selectively restore
        key_client = KeyClient(hsm_url, self.credential)
        key_name = self.get_resource_name("selective-restore-test-key")
        await key_client.create_rsa_key(key_name)

        # backup the vault and check its status and result
        backup_client = KeyVaultBackupClient(hsm_url, self.credential)
        backup_operation = await self._backup_and_verify(backup_client, container_uri, sas_token)

        # restore the key; the folder name is the last path segment of the backup's blob URI
        folder_name = backup_operation.azure_storage_blob_container_uri.rsplit("/", 1)[-1]
        restore_poller = await backup_client.begin_selective_restore(
            container_uri, sas_token, folder_name, key_name
        )

        # check restore status and result
        await self._verify_restore(backup_client, restore_poller)

        # delete the key through the base class's polling helper, with
        # ResourceExistsError passed as the expected exception, then purge it
        await self._poll_until_no_exception(
            key_client.delete_key, key_name, expected_exception=ResourceExistsError
        )
        await key_client.purge_deleted_key(key_name)
@pytest.mark.asyncio
async def test_continuation_token():
    """Every poller-returning client method should pass a continuation token through.

    The client's generated client is swapped for a Mock, each public ``begin_*``
    method is called once with ``continuation_token``, and the test asserts that
    each mocked generated method was called exactly once with that keyword.
    """
    expected_token = "token"
    generated_method_names = (
        "begin_full_backup",
        "begin_full_restore_operation",
        "begin_selective_key_restore_operation",
    )

    mock_generated_client = mock.Mock()
    mock_methods = []
    for name in generated_method_names:
        generated_method = getattr(mock_generated_client, name)
        # the mock client's methods must return awaitables, and we don't have AsyncMock before 3.8
        generated_method.return_value = get_completed_future()
        mock_methods.append(generated_method)

    backup_client = KeyVaultBackupClient("vault-url", object())
    backup_client._client = mock_generated_client

    storage_uri, sas, folder = "storage uri", "sas", "folder"
    await backup_client.begin_full_restore(storage_uri, sas, folder, continuation_token=expected_token)
    await backup_client.begin_full_backup(storage_uri, sas, continuation_token=expected_token)
    await backup_client.begin_selective_restore(storage_uri, sas, folder, "key", continuation_token=expected_token)

    for generated_method in mock_methods:
        assert generated_method.call_count == 1
        # call_args is an (args, kwargs) pair; index 1 holds the keyword arguments
        assert generated_method.call_args[1]["continuation_token"] == expected_token
|