1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
|
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
try:
from unittest.mock import Mock
except ImportError: # python < 3.3
from mock import Mock # type: ignore
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import (
ChainedTokenCredential,
ClientSecretCredential,
CredentialUnavailableError,
)
import pytest
def test_error_message():
    """The chain's error should name the failed credential types and carry every failure's message."""
    unavailable_message = "first_error"
    auth_failure_message = "second_error"

    # spec= makes the mock present itself as a ClientSecretCredential, which is
    # the type name the assertion below expects to find in the chain's message.
    unavailable = Mock(
        spec=ClientSecretCredential,
        get_token=Mock(side_effect=CredentialUnavailableError(unavailable_message)),
    )
    rejecting = Mock(
        name="second_credential",
        get_token=Mock(side_effect=ClientAuthenticationError(auth_failure_message)),
    )

    with pytest.raises(ClientAuthenticationError) as ex:
        chain = ChainedTokenCredential(unavailable, rejecting)
        chain.get_token("scope")

    message = ex.value.message
    for fragment in ("ClientSecretCredential", unavailable_message, auth_failure_message):
        assert fragment in message
def test_attempts_all_credentials():
    """Unavailable credentials are skipped in order until one produces a token."""
    expected_token = AccessToken("expected_token", 0)

    def make_unavailable():
        # Each call builds a fresh mock so call counts are tracked independently.
        return Mock(get_token=Mock(side_effect=CredentialUnavailableError(message="")))

    succeeding = Mock(get_token=Mock(return_value=expected_token))
    credentials = [make_unavailable(), make_unavailable(), succeeding]

    token = ChainedTokenCredential(*credentials).get_token("scope")

    assert token is expected_token
    # Every credential, including the successful last one, is asked exactly once.
    call_counts = [c.get_token.call_count for c in credentials]
    assert call_counts == [1] * len(credentials)
def test_raises_for_unexpected_error():
    """Any error other than CredentialUnavailableError must stop the chain immediately."""
    expected_message = "it can't be done"
    unavailable = Mock(get_token=Mock(side_effect=CredentialUnavailableError(message="")))
    broken = Mock(get_token=Mock(side_effect=ValueError(expected_message)))
    never_reached = Mock(get_token=Mock(return_value=AccessToken("**", 42)))

    with pytest.raises(ClientAuthenticationError) as ex:
        ChainedTokenCredential(unavailable, broken, never_reached).get_token("scope")

    # The ValueError's text should surface in the ClientAuthenticationError.
    assert expected_message in ex.value.message
    # The credential after the unexpected failure is never consulted.
    assert never_reached.get_token.call_count == 0
def test_returns_first_token():
    """The chain returns the first credential's token without consulting later credentials."""
    expected_token = AccessToken("expected_token", 0)
    # A Mock (unlike a fixed one-argument lambda) accepts whatever positional and
    # keyword arguments the chain forwards to get_token, and it records the call.
    # A lambda would raise TypeError on any forwarded keyword argument, which the
    # chain would report as an unexpected failure instead of returning the token.
    first_credential = Mock(get_token=Mock(return_value=expected_token))
    second_credential = Mock(get_token=Mock())

    aggregate = ChainedTokenCredential(first_credential, second_credential)
    token = aggregate.get_token("scope")

    assert token is expected_token
    assert first_credential.get_token.call_count == 1
    assert second_credential.get_token.call_count == 0
|