File: test_context_manager.py

package info (click to toggle)
python-azure 20250603%2Bgit-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (122 lines) | stat: -rw-r--r-- 3,970 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from unittest.mock import MagicMock, patch

from azure.identity import (
    AzureCliCredential,
    AzureDeveloperCliCredential,
    AzurePowerShellCredential,
    AuthorizationCodeCredential,
    CertificateCredential,
    ClientSecretCredential,
    DeviceCodeCredential,
    EnvironmentCredential,
    InteractiveBrowserCredential,
    OnBehalfOfCredential,
    SharedTokenCacheCredential,
    UsernamePasswordCredential,
    VisualStudioCodeCredential,
)
from azure.identity._constants import EnvironmentVariables

import pytest

from test_certificate_credential import PEM_CERT_PATH
from test_vscode_credential import GET_USER_SETTINGS


class CredentialFixture:
    """Factory that constructs one credential type for the parametrized tests in this module.

    :param cls: credential class to instantiate
    :param default_kwargs: keyword arguments passed to ``cls`` on every construction; keyword
        arguments given to :meth:`get_credential` override them
    :param ctor_patch_factory: zero-argument callable returning a context manager that is active
        while ``cls`` is constructed (e.g. a ``unittest.mock.patch``). Defaults to ``MagicMock``,
        whose instances support the context-manager protocol and so act as a stand-in that patches
        nothing.
    """

    def __init__(self, cls, default_kwargs=None, ctor_patch_factory=None):
        self.cls = cls
        self._default_kwargs = default_kwargs or {}
        self._ctor_patch_factory = ctor_patch_factory or MagicMock

    def get_credential(self, **kwargs):
        """Construct ``self.cls`` from the default kwargs merged with ``kwargs`` (``kwargs`` win).

        The constructor runs inside the context manager produced by ``ctor_patch_factory``.
        """
        # Deliberately not named ``patch``: that would shadow the module-level
        # ``unittest.mock.patch`` import used by the fixtures' patch factories.
        ctor_patch = self._ctor_patch_factory()
        with ctor_patch:
            return self.cls(**dict(self._default_kwargs, **kwargs))


# Credentials the tests below construct with a MagicMock ``transport`` keyword argument, to check
# how ``close()`` and the context-manager protocol drive that transport.
FIXTURES = (
    CredentialFixture(
        AuthorizationCodeCredential,
        dict.fromkeys(("tenant_id", "client_id", "authorization_code", "redirect_uri"), "..."),
    ),
    CredentialFixture(
        CertificateCredential, {"tenant_id": "...", "client_id": "...", "certificate_path": PEM_CERT_PATH}
    ),
    CredentialFixture(ClientSecretCredential, dict.fromkeys(("tenant_id", "client_id", "client_secret"), "...")),
    CredentialFixture(DeviceCodeCredential),
    CredentialFixture(
        EnvironmentCredential,
        # Sets every EnvironmentVariables.CLIENT_SECRET_VARS entry to a placeholder in the
        # ``os.environ`` seen by EnvironmentCredential's module while it is constructed
        # (presumably so it configures itself for client-secret auth).
        ctor_patch_factory=lambda: patch.dict(
            EnvironmentCredential.__module__ + ".os.environ",
            dict.fromkeys(EnvironmentVariables.CLIENT_SECRET_VARS, "..."),
        ),
    ),
    CredentialFixture(InteractiveBrowserCredential),
    CredentialFixture(
        OnBehalfOfCredential,
        dict.fromkeys(("tenant_id", "client_id", "client_secret", "user_assertion"), "..."),
    ),
    CredentialFixture(UsernamePasswordCredential, {"client_id": "...", "username": "...", "password": "..."}),
    # The user-settings lookup is replaced by a function returning an empty dict during construction
    # (presumably so the test never reads a real VS Code settings file).
    CredentialFixture(VisualStudioCodeCredential, ctor_patch_factory=lambda: patch(GET_USER_SETTINGS, lambda: {})),
)

# Parametrizes a test over FIXTURES, using each credential class name as the test id.
all_fixtures = pytest.mark.parametrize("fixture", FIXTURES, ids=lambda fixture: fixture.cls.__name__)


@all_fixtures
def test_close(fixture):
    """close() should exit the credential's transport exactly once, without ever entering it."""
    mock_transport = MagicMock()
    credential = fixture.get_credential(transport=mock_transport)

    # Merely constructing the credential must not touch the transport's context-manager methods.
    mock_transport.__enter__.assert_not_called()
    mock_transport.__exit__.assert_not_called()

    credential.close()

    mock_transport.__enter__.assert_not_called()
    mock_transport.__exit__.assert_called_once()


@all_fixtures
def test_context_manager(fixture):
    """Using the credential in a ``with`` block should enter and exit its transport once each."""
    mock_transport = MagicMock()
    credential = fixture.get_credential(transport=mock_transport)

    with credential:
        # Inside the block the transport has been entered but not yet exited.
        mock_transport.__enter__.assert_called_once()
        mock_transport.__exit__.assert_not_called()

    # Leaving the block exits the transport without entering it a second time.
    mock_transport.__enter__.assert_called_once()
    mock_transport.__exit__.assert_called_once()


@all_fixtures
def test_exit_args(fixture):
    """Arguments given to the credential's __exit__ should be forwarded unchanged to the transport."""
    mock_transport = MagicMock()
    credential = fixture.get_credential(transport=mock_transport)

    exc_type, exc_value, exc_tb = "type", "value", "traceback"
    credential.__exit__(exc_type, exc_value, exc_tb)

    mock_transport.__exit__.assert_called_once_with(exc_type, exc_value, exc_tb)


@pytest.mark.parametrize(
    "cls",
    [
        AzureCliCredential,
        AzureDeveloperCliCredential,
        AzurePowerShellCredential,
        EnvironmentCredential,
        SharedTokenCacheCredential,
    ],
)
def test_no_op(cls):
    """close() and the context-manager protocol should be no-ops for these credentials.

    Each of them either doesn't accept a custom transport, or requires initialization or optional
    configuration; here each is constructed with an empty ``os.environ``.
    """
    with patch.dict("os.environ", {}, clear=True):
        instance = cls()

    # Entering/exiting the credential, then closing it, should complete without raising.
    with instance:
        pass
    instance.close()