1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
|
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import time
from unittest import mock
from azure.core.credentials import AccessTokenInfo
import pytest
from azure.identity._constants import DEFAULT_REFRESH_OFFSET
from azure.identity.aio._internal.get_token_mixin import GetTokenMixin
from helpers import GET_TOKEN_METHODS
# Module-level pytest marker: applies the ``asyncio`` mark to every test in this file,
# so the ``async def`` tests below do not need to be decorated individually.
pytestmark = pytest.mark.asyncio
class MockCredential(GetTokenMixin):
    """Test double built on ``GetTokenMixin`` whose token sources are ``mock.Mock`` objects.

    The async hooks ``_acquire_token_silently`` and ``_request_token`` simply forward
    to the synchronous mocks ``acquire_token_silently`` and ``request_token``, which
    lets tests inspect how (and how often) each hook was invoked.
    """

    # Token handed back by the ``request_token`` mock unless a test replaces it.
    NEW_TOKEN = AccessTokenInfo("new token", 42)

    def __init__(self, cached_token=None):
        """Create the credential.

        :param cached_token: value the ``acquire_token_silently`` mock returns;
            ``None`` (the default) makes the mock return nothing.
        """
        super().__init__()
        self.token = cached_token
        self.acquire_token_silently = mock.Mock(return_value=cached_token)
        self.request_token = mock.Mock(return_value=MockCredential.NEW_TOKEN)

    async def _acquire_token_silently(self, *scopes, **kwargs):
        """Forward to the ``acquire_token_silently`` mock and return its result."""
        cached = self.acquire_token_silently(*scopes, **kwargs)
        return cached

    async def _request_token(self, *scopes, **kwargs):
        """Forward to the ``request_token`` mock and return its result."""
        requested = self.request_token(*scopes, **kwargs)
        return requested

    async def get_token(self, *_, **__):
        """Delegate unchanged to the parent class's ``get_token``."""
        result = await super().get_token(*_, **__)
        return result

    async def get_token_info(self, *_, **__):
        """Delegate unchanged to the parent class's ``get_token_info``."""
        result = await super().get_token_info(*_, **__)
        return result
# Token string used for tokens the tests pre-load into the credential's "cache"
# (i.e. the value returned by the acquire_token_silently mock).
CACHED_TOKEN = "cached token"
# Scope passed to every get_token / get_token_info call in these tests.
SCOPE = "scope"
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_no_cached_token(get_token_method):
    """With nothing cached, the credential should consult the cache and then request a new token."""
    # Keyword arguments both hooks are expected to receive for a default call.
    expected_kwargs = {"claims": None, "enable_cae": False, "tenant_id": None}
    credential = MockCredential()
    fetch = getattr(credential, get_token_method)

    result = await fetch(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE, **expected_kwargs)
    credential.request_token.assert_called_once_with(SCOPE, **expected_kwargs)
    assert result.token == MockCredential.NEW_TOKEN.token
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_tenant_id(get_token_method):
    """A tenant_id supplied by the caller should be forwarded to the token hooks.

    ``get_token`` takes ``tenant_id`` as a keyword argument, whereas
    ``get_token_info`` takes it inside the ``options`` mapping, so the call
    arguments are shaped per method.
    """
    credential = MockCredential()
    kwargs = {"tenant_id": "tenant_id"}
    if get_token_method == "get_token_info":
        kwargs = {"options": kwargs}

    token = await getattr(credential, get_token_method)(SCOPE, **kwargs)

    assert token.token == MockCredential.NEW_TOKEN.token
    # Previously the test only checked the returned token, so it would have
    # passed even if tenant_id were silently dropped. Verify it reaches both hooks,
    # using the same keyword layout the sibling tests assert for the default call.
    expected_kwargs = {"claims": None, "enable_cae": False, "tenant_id": "tenant_id"}
    credential.acquire_token_silently.assert_called_once_with(SCOPE, **expected_kwargs)
    credential.request_token.assert_called_once_with(SCOPE, **expected_kwargs)
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_token_acquisition_failure(get_token_method):
    """With nothing cached, a failing token request should raise, and every later call should try again."""
    expected_kwargs = {"claims": None, "enable_cae": False, "tenant_id": None}
    credential = MockCredential()
    # Replace the default mock so that every token request fails.
    credential.request_token = mock.Mock(side_effect=Exception("whoops"))
    fetch = getattr(credential, get_token_method)

    for attempt in range(1, 5):
        with pytest.raises(Exception):
            await fetch(SCOPE)
        # One request per call: failures must not suppress later attempts.
        assert credential.request_token.call_count == attempt
        credential.request_token.assert_called_with(SCOPE, **expected_kwargs)
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_expired_token(get_token_method):
    """An expired cached token should be discarded in favor of a freshly requested one."""
    expected_kwargs = {"claims": None, "enable_cae": False, "tenant_id": None}
    # Expiry one second in the past.
    expired_on = int(time.time()) - 1
    stale = AccessTokenInfo(CACHED_TOKEN, expired_on)
    credential = MockCredential(cached_token=stale)
    fetch = getattr(credential, get_token_method)

    result = await fetch(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE, **expected_kwargs)
    credential.request_token.assert_called_once_with(SCOPE, **expected_kwargs)
    assert result.token == MockCredential.NEW_TOKEN.token
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_cached_token_outside_refresh_window(get_token_method):
    """A cached token expiring later than the refresh offset should be returned without a new request."""
    expected_kwargs = {"claims": None, "enable_cae": False, "tenant_id": None}
    # Expires one second beyond the refresh offset from now.
    expires_on = int(time.time() + DEFAULT_REFRESH_OFFSET + 1)
    credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, expires_on))
    fetch = getattr(credential, get_token_method)

    result = await fetch(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE, **expected_kwargs)
    assert credential.request_token.call_count == 0
    assert result.token == CACHED_TOKEN
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_cached_token_within_refresh_window(get_token_method):
    """A cached token expiring sooner than the refresh offset should trigger a request for a new token."""
    expected_kwargs = {"claims": None, "enable_cae": False, "tenant_id": None}
    # Expires one second short of the refresh offset from now.
    expires_on = int(time.time() + DEFAULT_REFRESH_OFFSET - 1)
    credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, expires_on))
    fetch = getattr(credential, get_token_method)

    result = await fetch(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE, **expected_kwargs)
    credential.request_token.assert_called_once_with(SCOPE, **expected_kwargs)
    assert result.token == MockCredential.NEW_TOKEN.token
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_retry_delay(get_token_method):
    """After a failed refresh attempt, repeated calls should return the cached token without re-requesting."""
    expected_kwargs = {"claims": None, "enable_cae": False, "tenant_id": None}
    start = time.time()
    # Cached token expires one second short of the refresh offset from now.
    expires_on = int(start + DEFAULT_REFRESH_OFFSET - 1)
    credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, expires_on))

    # the credential should swallow exceptions during proactive refresh attempts
    credential.request_token = mock.Mock(side_effect=Exception("whoops"))
    fetch = getattr(credential, get_token_method)

    for _ in range(4):
        result = await fetch(SCOPE)
        assert result.token == CACHED_TOKEN
        credential.acquire_token_silently.assert_called_with(SCOPE, **expected_kwargs)
        # Still exactly one request no matter how many calls have been made.
        credential.request_token.assert_called_once_with(SCOPE, **expected_kwargs)
|