File: test_get_token_mixin_async.py

package info (click to toggle)
python-azure 20250603%2Bgit-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (135 lines) | stat: -rw-r--r-- 5,800 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import time
from unittest import mock

from azure.core.credentials import AccessTokenInfo
import pytest

from azure.identity._constants import DEFAULT_REFRESH_OFFSET
from azure.identity.aio._internal.get_token_mixin import GetTokenMixin

from helpers import GET_TOKEN_METHODS

# Apply the asyncio marker to every test in this module.
pytestmark = pytest.mark.asyncio


class MockCredential(GetTokenMixin):
    """Test double for ``GetTokenMixin`` whose token sources are ``mock.Mock`` objects.

    The coroutine hooks ``_acquire_token_silently`` and ``_request_token`` delegate to the
    synchronous ``acquire_token_silently`` and ``request_token`` mocks, so tests can set
    return values or side effects on them and inspect the calls they recorded.
    """

    # Token returned by the default ``request_token`` mock.
    NEW_TOKEN = AccessTokenInfo("new token", 42)

    def __init__(self, cached_token=None):
        """Build the credential and its mocks.

        :param cached_token: value returned by the ``acquire_token_silently`` mock.
            Defaults to ``None``.
        """
        # Zero-argument super(), consistent with the other super() calls in this class.
        super().__init__()
        self.token = cached_token
        self.request_token = mock.Mock(return_value=MockCredential.NEW_TOKEN)
        self.acquire_token_silently = mock.Mock(return_value=cached_token)

    async def _acquire_token_silently(self, *scopes, **kwargs):
        """Return whatever the ``acquire_token_silently`` mock returns for these arguments."""
        return self.acquire_token_silently(*scopes, **kwargs)

    async def _request_token(self, *scopes, **kwargs):
        """Return (or raise) whatever the ``request_token`` mock produces for these arguments."""
        return self.request_token(*scopes, **kwargs)

    async def get_token(self, *scopes, **kwargs):
        """Forward all arguments unchanged to the parent class's ``get_token``."""
        return await super().get_token(*scopes, **kwargs)

    async def get_token_info(self, *scopes, **kwargs):
        """Forward all arguments unchanged to the parent class's ``get_token_info``."""
        return await super().get_token_info(*scopes, **kwargs)


# Token string carried by the cached tokens that the tests below construct.
CACHED_TOKEN = "cached token"
# Scope argument passed to the token method in every test.
SCOPE = "scope"


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_no_cached_token(get_token_method):
    """With nothing cached, calling the token method should trigger a token request"""

    expected_kwargs = {"claims": None, "enable_cae": False, "tenant_id": None}
    credential = MockCredential()
    fetch = getattr(credential, get_token_method)

    result = await fetch(SCOPE)

    # both the silent lookup and the request are expected exactly once, with identical arguments
    credential.acquire_token_silently.assert_called_once_with(SCOPE, **expected_kwargs)
    credential.request_token.assert_called_once_with(SCOPE, **expected_kwargs)
    assert result.token == MockCredential.NEW_TOKEN.token


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_tenant_id(get_token_method):
    """A tenant_id supplied by the caller should be forwarded when acquiring a token"""

    tenant_id = "tenant_id"
    credential = MockCredential()
    # get_token_info receives tenant_id inside its "options" mapping; otherwise it is a plain keyword
    kwargs = {"tenant_id": tenant_id}
    if get_token_method == "get_token_info":
        kwargs = {"options": kwargs}
    token = await getattr(credential, get_token_method)(SCOPE, **kwargs)

    assert token.token == MockCredential.NEW_TOKEN.token
    # Checking only the returned token would let a silently dropped tenant_id pass, so also
    # verify the value reached both mocks (same keyword shape the other tests assert).
    credential.acquire_token_silently.assert_called_once_with(
        SCOPE, claims=None, enable_cae=False, tenant_id=tenant_id
    )
    credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=tenant_id)


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_token_acquisition_failure(get_token_method):
    """With no cached token, each call should issue a fresh token request even after failures"""

    attempts = 4
    failing_request = mock.Mock(side_effect=Exception("whoops"))
    credential = MockCredential()
    credential.request_token = failing_request
    get_token = getattr(credential, get_token_method)

    for expected_calls in range(1, attempts + 1):
        # the failure must surface to the caller...
        with pytest.raises(Exception):
            await get_token(SCOPE)
        # ...and every attempt must have reached the request mock
        assert failing_request.call_count == expected_calls
        failing_request.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None)


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_expired_token(get_token_method):
    """An expired cached token should cause the credential to request a new one"""

    expected_kwargs = {"claims": None, "enable_cae": False, "tenant_id": None}
    # expiry one second in the past
    expired_on = int(time.time()) - 1
    stale_token = AccessTokenInfo(CACHED_TOKEN, expired_on)
    credential = MockCredential(cached_token=stale_token)

    result = await getattr(credential, get_token_method)(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE, **expected_kwargs)
    credential.request_token.assert_called_once_with(SCOPE, **expected_kwargs)
    assert result.token == MockCredential.NEW_TOKEN.token


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_cached_token_outside_refresh_window(get_token_method):
    """A cached token with enough validity left should be returned without a new request"""

    # expires one second beyond the default refresh offset
    expires_on = int(time.time() + DEFAULT_REFRESH_OFFSET + 1)
    fresh_token = AccessTokenInfo(CACHED_TOKEN, expires_on)
    credential = MockCredential(cached_token=fresh_token)
    fetch = getattr(credential, get_token_method)

    result = await fetch(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None)
    credential.request_token.assert_not_called()
    assert result.token == CACHED_TOKEN


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_cached_token_within_refresh_window(get_token_method):
    """A cached token inside the refresh window should be replaced by a newly requested one"""

    expected_kwargs = {"claims": None, "enable_cae": False, "tenant_id": None}
    # expires one second short of the default refresh offset
    expires_on = int(time.time() + DEFAULT_REFRESH_OFFSET - 1)
    credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, expires_on))
    fetch = getattr(credential, get_token_method)

    result = await fetch(SCOPE)

    credential.acquire_token_silently.assert_called_once_with(SCOPE, **expected_kwargs)
    credential.request_token.assert_called_once_with(SCOPE, **expected_kwargs)
    assert result.token == MockCredential.NEW_TOKEN.token


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_retry_delay(get_token_method):
    """After a failed refresh, the credential should not immediately retry the request"""

    expected_kwargs = {"claims": None, "enable_cae": False, "tenant_id": None}
    # cached token expires one second short of the default refresh offset
    expires_on = int(time.time() + DEFAULT_REFRESH_OFFSET - 1)
    credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, expires_on))

    # refresh attempts fail; the credential is expected to swallow the error and keep the cached token
    credential.request_token = mock.Mock(side_effect=Exception("whoops"))

    for _ in range(4):
        result = await getattr(credential, get_token_method)(SCOPE)
        assert result.token == CACHED_TOKEN
        credential.acquire_token_silently.assert_called_with(SCOPE, **expected_kwargs)
        # still only one request, no matter how many calls have been made so far
        credential.request_token.assert_called_once_with(SCOPE, **expected_kwargs)