1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
|
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
from unittest.mock import Mock, patch
from azure.core.credentials import AccessToken
from azure.identity import CredentialUnavailableError
from azure.identity.aio._credentials.application import AzureApplicationCredential
from azure.identity._constants import EnvironmentVariables
import pytest
from six.moves.urllib_parse import urlparse
from helpers import build_aad_response, mock_response
from helpers_async import get_completed_future
@pytest.mark.asyncio
async def test_iterates_only_once():
    """After one credential in the chain returns a token, AzureApplicationCredential keeps using it.

    The chain holds an unavailable credential, a successful credential, and a sentinel that
    raises if it is ever reached. Repeated get_token calls must not re-consult the unavailable
    credential and must never reach the sentinel.
    """
    expected_token = AccessToken("***", 42)

    unavailable = Mock(get_token=Mock(side_effect=CredentialUnavailableError(message="...")))
    successful = Mock(get_token=Mock(return_value=get_completed_future(expected_token)))
    # Reaching this credential would mean iteration continued past a successful credential.
    sentinel = Mock(
        get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token"))
    )

    credential = AzureApplicationCredential()
    credential.credentials = [unavailable, successful, sentinel]

    for attempt in range(1, 4):
        token = await credential.get_token("scope")
        assert token.token == expected_token.token
        # The unavailable credential is only tried once, on the first pass through the chain.
        assert unavailable.get_token.call_count == 1
        assert successful.get_token.call_count == attempt
@pytest.mark.parametrize("authority", ("localhost", "https://localhost"))
def test_authority(authority):
    """AzureApplicationCredential accepts an authority by keyword argument or environment variable.

    The authority must reach EnvironmentCredential as an https URL with the expected host,
    and must not be passed to ManagedIdentityCredential at all.
    """
    # "localhost" parses to netloc "" and path "localhost", so fall back to the raw value
    expected_netloc = urlparse(authority).netloc or authority

    def _check_initialization(mock_credential, expect_argument):
        # Supply the authority as a keyword argument first...
        AzureApplicationCredential(authority=authority)
        assert mock_credential.call_count == 1

        # ...then through the environment.
        # N.B. if os.environ has been patched somewhere in the stack, that patch is in place here
        environment = dict(os.environ, **{EnvironmentVariables.AZURE_AUTHORITY_HOST: authority})
        with patch.dict(AzureApplicationCredential.__module__ + ".os.environ", environment, clear=True):
            AzureApplicationCredential()
        assert mock_credential.call_count == 2

        for _, kwargs in mock_credential.call_args_list:
            if not expect_argument:
                assert "authority" not in kwargs
                continue
            actual = urlparse(kwargs["authority"])
            assert actual.scheme == "https"
            assert actual.netloc == expected_netloc

    module = AzureApplicationCredential.__module__

    # authority should be passed to EnvironmentCredential as a keyword argument
    secret_environment = {name: "foo" for name in EnvironmentVariables.CLIENT_SECRET_VARS}
    with patch(module + ".EnvironmentCredential") as mock_credential, patch.dict(
        "os.environ", secret_environment, clear=True
    ):
        _check_initialization(mock_credential, expect_argument=True)

    # authority should not be passed to ManagedIdentityCredential
    with patch(module + ".ManagedIdentityCredential") as mock_credential, patch.dict(
        "os.environ", {EnvironmentVariables.MSI_ENDPOINT: "localhost"}, clear=True
    ):
        _check_initialization(mock_credential, expect_argument=False)
@pytest.mark.asyncio
async def test_get_token():
    """get_token returns the access token carried in the AAD response from the transport."""
    access_token = "***"

    async def send(request, **_):
        # Every request is answered with a canned AAD token response carrying access_token.
        payload = build_aad_response(access_token=access_token)
        return mock_response(json_payload=payload)

    client_secret_environment = {name: "..." for name in EnvironmentVariables.CLIENT_SECRET_VARS}
    with patch.dict("os.environ", client_secret_environment, clear=True):
        credential = AzureApplicationCredential(transport=Mock(send=send))

    token = await credential.get_token("scope")
    assert token.token == access_token
def test_managed_identity_client_id():
    """the credential should accept a user-assigned managed identity's client ID by kwarg or environment variable"""
    expected_args = {"client_id": "the-client"}
    mi_credential_target = AzureApplicationCredential.__module__ + ".ManagedIdentityCredential"

    # client id can be specified by keyword argument. The environment is cleared so that an
    # ambient $AZURE_CLIENT_ID, or variables read by the unpatched EnvironmentCredential,
    # cannot influence this scenario -- the same isolation the two scenarios below use.
    with patch.dict(os.environ, {}, clear=True):
        with patch(mi_credential_target) as mock_credential:
            AzureApplicationCredential(managed_identity_client_id=expected_args["client_id"])
        mock_credential.assert_called_once_with(**expected_args)

    # client id can also be specified in $AZURE_CLIENT_ID
    with patch.dict(os.environ, {EnvironmentVariables.AZURE_CLIENT_ID: expected_args["client_id"]}, clear=True):
        with patch(mi_credential_target) as mock_credential:
            AzureApplicationCredential()
        mock_credential.assert_called_once_with(**expected_args)

    # keyword argument should override environment variable
    with patch.dict(
        os.environ, {EnvironmentVariables.AZURE_CLIENT_ID: "not-" + expected_args["client_id"]}, clear=True
    ):
        with patch(mi_credential_target) as mock_credential:
            AzureApplicationCredential(managed_identity_client_id=expected_args["client_id"])
        mock_credential.assert_called_once_with(**expected_args)
|