File: test_managed_identity_client_async.py

package info (click to toggle)
python-azure 20250603%2Bgit-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (150 lines) | stat: -rw-r--r-- 4,919 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import json
import time
from unittest.mock import Mock, patch

from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError
from azure.core.pipeline.transport import HttpRequest
from azure.identity.aio._internal.managed_identity_client import AsyncManagedIdentityClient
import pytest

from helpers import mock_response, Request
from helpers_async import async_validating_transport, AsyncMockTransport, get_completed_future

pytestmark = pytest.mark.asyncio


async def test_close():
    transport = AsyncMockTransport()
    client = AsyncManagedIdentityClient(lambda *_: None, transport=transport)

    await client.close()

    assert transport.__aexit__.call_count == 1


async def test_context_manager():
    transport = AsyncMockTransport()
    client = AsyncManagedIdentityClient(lambda *_: None, transport=transport)

    async with client:
        assert transport.__aenter__.call_count == 1
        assert transport.__aexit__.call_count == 0

    assert transport.__aenter__.call_count == 1
    assert transport.__aexit__.call_count == 1


async def test_caching():
    scope = "scope"
    now = int(time.time())
    expected_expires_on = now + 3600
    expected_token = "*"
    transport = async_validating_transport(
        requests=[Request(url="http://localhost")],
        responses=[
            mock_response(
                json_payload={
                    "access_token": expected_token,
                    "expires_in": 3600,
                    "expires_on": expected_expires_on,
                    "resource": scope,
                    "token_type": "Bearer",
                }
            )
        ],
    )
    client = AsyncManagedIdentityClient(
        request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=transport
    )

    token = client.get_cached_token(scope)
    assert not token

    with patch(AsyncManagedIdentityClient.__module__ + ".time.time", lambda: now):
        token = await client.request_token(scope)
    assert token.expires_on == expected_expires_on
    assert token.token == expected_token

    token = client.get_cached_token(scope)
    assert token.expires_on == expected_expires_on
    assert token.token == expected_token


async def test_deserializes_json_from_text():
    """The client should gracefully handle a response with a JSON body and content-type text/plain"""

    scope = "scope"
    now = int(time.time())
    expected_expires_on = now + 3600
    expected_token = "*"

    async def send(request, **_):
        body = json.dumps(
            {
                "access_token": expected_token,
                "expires_in": 3600,
                "expires_on": expected_expires_on,
                "resource": scope,
                "token_type": "Bearer",
            }
        )
        return Mock(
            status_code=200,
            headers={"Content-Type": "text/plain"},
            content_type="text/plain",
            text=lambda encoding=None: body,
        )

    client = AsyncManagedIdentityClient(
        request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=Mock(send=send)
    )

    token = await client.request_token(scope)
    assert token.expires_on == expected_expires_on
    assert token.token == expected_token


@pytest.mark.asyncio
async def test_managed_identity_client_retry():
    """AsyncManagedIdentityClient should retry token requests"""

    message = "can't connect"
    transport = Mock(send=Mock(side_effect=ServiceRequestError(message)), sleep=get_completed_future)
    request_factory = Mock()

    client = AsyncManagedIdentityClient(request_factory, transport=transport)

    for method in ("GET", "POST"):
        request_factory.return_value = HttpRequest(method, "https://localhost")
        with pytest.raises(ServiceRequestError, match=message):
            await client.request_token("scope")
        assert transport.send.call_count > 1
        transport.send.reset_mock()


@pytest.mark.parametrize("content_type", ("text/html", "application/json"))
async def test_unexpected_content(content_type):
    content = "<html><body>not JSON</body></html>"

    async def send(request, **_):
        return Mock(
            status_code=200,
            headers={"Content-Type": content_type},
            content_type=content_type,
            text=lambda encoding=None: content,
        )

    client = AsyncManagedIdentityClient(
        request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=Mock(send=send)
    )

    with pytest.raises(ClientAuthenticationError) as ex:
        await client.request_token("scope")
    assert ex.value.response.text() == content

    if "json" not in content_type:
        assert content_type in ex.value.message