1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
|
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import time
from unittest import mock
from azure.core.credentials import AccessTokenInfo
import pytest
from azure.identity._constants import DEFAULT_REFRESH_OFFSET
from azure.identity._internal.get_token_mixin import GetTokenMixin
from helpers import GET_TOKEN_METHODS
class MockCredential(GetTokenMixin):
    """GetTokenMixin test double whose two token-acquisition hooks are ``mock.Mock`` objects.

    ``acquire_token_silently`` returns the ``cached_token`` given to the constructor (``None`` by
    default), and ``request_token`` always returns ``NEW_TOKEN``. Tests inspect or replace these
    mocks to observe how the mixin drives the hooks.
    """

    NEW_TOKEN = AccessTokenInfo("new token", 42)

    def __init__(self, cached_token=None):
        super().__init__()
        # Mocks stand in for the real hooks so tests can assert on call arguments and counts.
        self.acquire_token_silently = mock.Mock(return_value=cached_token)
        self.request_token = mock.Mock(return_value=MockCredential.NEW_TOKEN)

    def _acquire_token_silently(self, *scopes, **kwargs):
        # Look the attribute up at call time so a test may swap in a different mock.
        return self.acquire_token_silently(*scopes, **kwargs)

    def _request_token(self, *scopes, **kwargs):
        return self.request_token(*scopes, **kwargs)

    def get_token(self, *args, **kwargs):
        # Forwarded unchanged to the mixin implementation.
        return super().get_token(*args, **kwargs)

    def get_token_info(self, *args, **kwargs):
        # Forwarded unchanged to the mixin implementation.
        return super().get_token_info(*args, **kwargs)
# Scope string passed to every get_token / get_token_info call in these tests.
SCOPE = "scope"
# Token string the mocked cache lookup hands back when a test seeds a cached token.
CACHED_TOKEN = "cached token"
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_no_cached_token(get_token_method):
    """Without a cached token, a credential should request a new one whenever get_token is called."""
    default_options = {"claims": None, "enable_cae": False, "tenant_id": None}
    credential = MockCredential()
    fetch = getattr(credential, get_token_method)

    token = fetch(SCOPE)

    # Both hooks should have been invoked exactly once, with the default request options.
    credential.acquire_token_silently.assert_called_once_with(SCOPE, **default_options)
    credential.request_token.assert_called_once_with(SCOPE, **default_options)
    assert token.token == MockCredential.NEW_TOKEN.token
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_tenant_id(get_token_method):
    """A caller-supplied tenant_id should be forwarded to both token-acquisition hooks.

    get_token takes tenant_id as a keyword argument, while get_token_info takes it inside the
    ``options`` mapping. Either way the hooks must receive it, so this checks the call arguments
    and not only the returned token.
    """
    tenant_id = "tenant_id"
    credential = MockCredential()
    kwargs = {"tenant_id": tenant_id}
    if get_token_method == "get_token_info":
        kwargs = {"options": kwargs}

    token = getattr(credential, get_token_method)(SCOPE, **kwargs)

    # Without these assertions the test would still pass if tenant_id were silently dropped.
    credential.acquire_token_silently.assert_called_once_with(
        SCOPE, claims=None, enable_cae=False, tenant_id=tenant_id
    )
    credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=tenant_id)
    assert token.token == MockCredential.NEW_TOKEN.token
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_token_acquisition_failure(get_token_method):
    """When the credential has no token cached, every get_token call should prompt a token request,
    even after earlier requests have failed."""
    credential = MockCredential()
    credential.request_token = mock.Mock(side_effect=Exception("whoops"))

    for attempt in range(1, 5):
        # Match the mocked failure so an unrelated error (e.g. a TypeError from a bad call
        # signature) cannot satisfy this check by accident.
        with pytest.raises(Exception, match="whoops"):
            getattr(credential, get_token_method)(SCOPE)
        assert credential.request_token.call_count == attempt
        credential.request_token.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None)
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_expired_token(get_token_method):
    """A cached token that has already expired should be replaced by a freshly requested one."""
    expired_at = int(time.time()) - 1
    stale_token = AccessTokenInfo(CACHED_TOKEN, expired_at)
    credential = MockCredential(cached_token=stale_token)
    default_options = {"claims": None, "enable_cae": False, "tenant_id": None}

    token = getattr(credential, get_token_method)(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE, **default_options)
    credential.request_token.assert_called_once_with(SCOPE, **default_options)
    assert token.token == MockCredential.NEW_TOKEN.token
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_cached_token_outside_refresh_window(get_token_method):
    """A credential shouldn't request a new token when it has a cached one with sufficient validity remaining"""
    # int() truncates the expiry by up to a second, and the credential reads the clock again when it
    # evaluates the token. With only a 1s margin, crossing a second boundary in between can leave the
    # token exactly DEFAULT_REFRESH_OFFSET seconds from expiry. Whether that triggers a refresh depends
    # on the mixin's comparison (NOTE(review): presumably strict), which would make the test flaky.
    # A wider margin keeps the token unambiguously outside the refresh window.
    margin = 10
    credential = MockCredential(
        cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + DEFAULT_REFRESH_OFFSET + margin))
    )

    token = getattr(credential, get_token_method)(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None)
    assert credential.request_token.call_count == 0
    assert token.token == CACHED_TOKEN
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_cached_token_within_refresh_window(get_token_method):
    """A cached token that is close enough to expiry to fall inside the refresh window should be replaced."""
    expires_on = int(time.time() + DEFAULT_REFRESH_OFFSET - 1)
    credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, expires_on))
    default_options = {"claims": None, "enable_cae": False, "tenant_id": None}

    token = getattr(credential, get_token_method)(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE, **default_options)
    credential.request_token.assert_called_once_with(SCOPE, **default_options)
    assert token.token == MockCredential.NEW_TOKEN.token
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_retry_delay(get_token_method):
    """After a failed proactive refresh, the credential should wait before trying again instead of
    re-requesting on every call, and should keep serving the cached token meanwhile."""
    expires_on = int(time.time() + DEFAULT_REFRESH_OFFSET - 1)
    credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, expires_on))
    # the credential should swallow exceptions during proactive refresh attempts
    credential.request_token = mock.Mock(side_effect=Exception("whoops"))
    default_options = {"claims": None, "enable_cae": False, "tenant_id": None}

    for _ in range(4):
        token = getattr(credential, get_token_method)(SCOPE)
        assert token.token == CACHED_TOKEN
        credential.acquire_token_silently.assert_called_with(SCOPE, **default_options)
        # Only one refresh attempt should ever be made across these back-to-back calls.
        credential.request_token.assert_called_once_with(SCOPE, **default_options)
|