File: test_client_assertion_credential_async.py

package info (click to toggle)
python-azure 20250603+git-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (92 lines) | stat: -rw-r--r-- 3,445 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import Callable
from unittest.mock import MagicMock, Mock, patch

import pytest
from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION
from azure.identity import TokenCachePersistenceOptions
from azure.identity.aio import ClientAssertionCredential

from helpers import build_aad_response, mock_response, GET_TOKEN_METHODS


def test_init_with_kwargs():
    """Constructing the credential with extra keyword arguments should not raise.

    Covers both the ``authority`` keyword and arbitrary, unrecognized keywords
    (``foo``/``bar``). The test passes as long as construction succeeds.
    """
    tenant_id: str = "TENANT_ID"
    client_id: str = "CLIENT_ID"
    func: Callable[[], str] = lambda: "TOKEN"

    # Construction with the ``authority`` keyword must succeed.
    ClientAssertionCredential(tenant_id=tenant_id, client_id=client_id, func=func, authority="a")

    # Test arbitrary keyword argument: unknown keywords must not make construction fail.
    ClientAssertionCredential(tenant_id=tenant_id, client_id=client_id, func=func, foo="a", bar="b")


@pytest.mark.asyncio
async def test_context_manager():
    """Using the credential as an async context manager should enter and exit its transport."""
    mock_transport = MagicMock()
    cred = ClientAssertionCredential(
        tenant_id="TENANT_ID",
        client_id="CLIENT_ID",
        func=lambda: "TOKEN",
        transport=mock_transport,
    )

    async with cred:
        # Inside the block the transport has been entered but not yet exited.
        assert mock_transport.__aenter__.called
        assert not mock_transport.__aexit__.called

    # Leaving the block must exit the transport.
    assert mock_transport.__aexit__.called


@pytest.mark.asyncio
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_token_cache_persistence(get_token_method):
    """The credential should use a persistent cache if cache_persistence_options are configured.

    The persistent cache is loaded lazily. Nothing is loaded at construction time.
    The first token request loads the regular cache. The first CAE-enabled request
    loads the separate CAE cache.
    """

    access_token = "foo"
    tenant_id: str = "TENANT_ID"
    client_id: str = "CLIENT_ID"
    scope = "scope"
    assertion = "ASSERTION_TOKEN"
    func: Callable[[], str] = lambda: assertion

    async def send(request, **kwargs):
        # Verify the token request is a client-credentials grant that carries the assertion.
        assert request.data["client_assertion"] == assertion
        assert request.data["client_assertion_type"] == JWT_BEARER_ASSERTION
        assert request.data["client_id"] == client_id
        assert request.data["grant_type"] == "client_credentials"
        assert request.data["scope"] == scope

        return mock_response(json_payload=build_aad_response(access_token=access_token))

    with patch("azure.identity._internal.aad_client_base._load_persistent_cache") as load_persistent_cache:
        credential = ClientAssertionCredential(
            tenant_id=tenant_id,
            client_id=client_id,
            func=func,
            cache_persistence_options=TokenCachePersistenceOptions(),
            transport=Mock(send=send),
        )

        # No cache is loaded until a token is requested.
        assert load_persistent_cache.call_count == 0
        assert credential._client._cache is None
        assert credential._client._cae_cache is None

        # A non-CAE request loads only the regular cache.
        token = await getattr(credential, get_token_method)(scope)
        assert token.token == access_token
        assert load_persistent_cache.call_count == 1
        assert credential._client._cache is not None
        assert credential._client._cae_cache is None

        # get_token takes enable_cae directly; get_token_info takes it inside an options mapping.
        cae_kwargs = {"enable_cae": True}
        if get_token_method == "get_token_info":
            cae_kwargs = {"options": cae_kwargs}
        token = await getattr(credential, get_token_method)(scope, **cae_kwargs)
        # The CAE-enabled request must still return the mocked token.
        assert token.token == access_token
        # It also loads the separate CAE cache.
        assert load_persistent_cache.call_count == 2
        assert credential._client._cae_cache is not None