File: test_managed_identity_client.py

package info (click to toggle)
python-azure 20250603+git-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (124 lines) | stat: -rw-r--r-- 4,053 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import json
import time

from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError
from azure.core.pipeline.transport import HttpRequest
from azure.identity._internal.managed_identity_client import ManagedIdentityClient
import pytest

from helpers import mock, mock_response, Request, validating_transport


def test_caching():
    """A token obtained via request_token should afterwards be returned by get_cached_token."""
    scope = "scope"
    issued_at = int(time.time())
    expires_on = issued_at + 3600
    access_token = "*"

    payload = {
        "access_token": access_token,
        "expires_in": 3600,
        "expires_on": expires_on,
        "resource": scope,
        "token_type": "Bearer",
    }
    # validating_transport answers the single expected request with the payload above
    transport = validating_transport(
        requests=[Request(url="http://localhost")],
        responses=[mock_response(json_payload=payload)],
    )

    def request_factory(_, __):
        return HttpRequest("GET", "http://localhost")

    client = ManagedIdentityClient(request_factory=request_factory, transport=transport)

    def check(token):
        assert token.expires_on == expires_on
        assert token.token == access_token

    # nothing has been requested yet, so there should be no cached token
    assert not client.get_cached_token(scope)

    # time.time (resolved through the client's module) is frozen at issued_at while requesting
    with mock.patch(ManagedIdentityClient.__module__ + ".time.time", lambda: issued_at):
        fetched = client.request_token(scope)
    check(fetched)

    check(client.get_cached_token(scope))


def test_deserializes_json_from_text():
    """A JSON body served with content-type text/plain should still be parsed into a token."""
    scope = "scope"
    expected_expires_on = int(time.time()) + 3600
    expected_token = "*"
    plain_text = "text/plain"

    body = json.dumps(
        {
            "access_token": expected_token,
            "expires_in": 3600,
            "expires_on": expected_expires_on,
            "resource": scope,
            "token_type": "Bearer",
        }
    )

    def send(request, **_):
        # a fresh fake response per call, labelled as plain text despite carrying JSON
        return mock.Mock(
            status_code=200,
            headers={"Content-Type": plain_text},
            content_type=plain_text,
            text=lambda encoding=None: body,
        )

    def request_factory(_, __):
        return HttpRequest("GET", "http://localhost")

    client = ManagedIdentityClient(request_factory=request_factory, transport=mock.Mock(send=send))

    result = client.request_token(scope)
    assert result.expires_on == expected_expires_on
    assert result.token == expected_token


def test_retry():
    """request_token should attempt the transport more than once before giving up, for GET and POST."""
    message = "can't connect"
    transport = mock.Mock(send=mock.Mock(side_effect=ServiceRequestError(message)))
    attempts = transport.send  # every call raises ServiceRequestError(message)
    request_factory = mock.Mock()

    client = ManagedIdentityClient(request_factory, transport=transport)

    for http_method in ("GET", "POST"):
        request_factory.return_value = HttpRequest(http_method, "https://localhost")

        with pytest.raises(ServiceRequestError, match=message):
            client.request_token("scope")

        # more than one send means the client retried the failing request
        assert attempts.call_count > 1
        # clear the call history (side_effect is kept) before the next method
        attempts.reset_mock()


@pytest.mark.parametrize("content_type", ("text/html", "application/json"))
def test_unexpected_content(content_type):
    """A 200 response with a non-JSON body should raise ClientAuthenticationError exposing the response."""
    content = "<html><body>not JSON</body></html>"

    def send(request, **_):
        headers = {"Content-Type": content_type}
        return mock.Mock(
            status_code=200,
            headers=headers,
            content_type=content_type,
            text=lambda encoding=None: content,
        )

    def request_factory(_, __):
        return HttpRequest("GET", "http://localhost")

    client = ManagedIdentityClient(request_factory=request_factory, transport=mock.Mock(send=send))

    with pytest.raises(ClientAuthenticationError) as exc_info:
        client.request_token("scope")

    error = exc_info.value
    assert error.response.text() == content

    # only content types that do not claim to be JSON are checked for in the message
    if "json" not in content_type:
        assert content_type in error.message