File: test_lite_management_client.py

package info (click to toggle)
python-azure 20250603+git-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (87 lines) | stat: -rw-r--r-- 3,379 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Any, Mapping
import pytest
import logging
from azure.core.credentials import AzureSasCredential, TokenCredential, AccessToken
from azure.ai.evaluation._azure._clients import LiteMLClient


@pytest.mark.usefixtures("model_config", "project_scope", "recording_injection", "recorded_test")
class TestLiteAzureManagementClient(object):
    """End to end tests for the lite Azure management client."""

    @staticmethod
    def _create_client(project_scope: Mapping[str, Any], azure_cred: TokenCredential) -> LiteMLClient:
        """Build a ``LiteMLClient`` for the subscription and resource group named in ``project_scope``.

        :param project_scope: Mapping with at least ``subscription_id`` and ``resource_group_name`` keys.
        :param azure_cred: Credential passed straight through to the client.
        :return: A client that logs to this module's logger.
        """
        return LiteMLClient(
            subscription_id=project_scope["subscription_id"],
            resource_group=project_scope["resource_group_name"],
            credential=azure_cred,
            logger=logging.getLogger(__name__),
        )

    @pytest.mark.azuretest
    def test_get_credential(self, project_scope: Mapping[str, Any], azure_cred: TokenCredential):
        """``get_credential`` returns a ``TokenCredential``."""
        client = self._create_client(project_scope, azure_cred)

        credential = client.get_credential()
        assert isinstance(credential, TokenCredential)

    @pytest.mark.azuretest
    def test_get_token(self, project_scope: Mapping[str, Any], azure_cred: TokenCredential):
        """``get_token`` returns an ``AccessToken`` whose token string is non-empty."""
        client = self._create_client(project_scope, azure_cred)

        token: AccessToken = client.get_token()
        assert isinstance(token, AccessToken) and len(token.token) > 0

    @pytest.mark.azuretest
    @pytest.mark.parametrize("include_credentials", [False, True])
    @pytest.mark.parametrize("config_name", ["sas", "none"])
    def test_workspace_get_default_store(
        self,
        azure_cred: TokenCredential,
        datastore_project_scopes: Mapping[str, Any],
        config_name: str,
        include_credentials: bool,
    ):
        """The workspace's default datastore is fully populated.

        If credentials are requested, the datastore's credential type must match the
        project's datastore configuration. Otherwise, no credential may be attached.
        """
        project_scope = datastore_project_scopes[config_name]
        client = self._create_client(project_scope, azure_cred)

        store = client.workspace_get_default_datastore(
            workspace_name=project_scope["project_name"], include_credentials=include_credentials
        )

        assert store
        assert store.name
        assert store.account_name
        assert store.endpoint
        assert store.container_name
        if include_credentials:
            # Expected credential type for each datastore configuration. "account_key" is not
            # in the parametrize list above, so that entry is currently unused.
            expected_credential_types = {
                "account_key": str,
                "sas": AzureSasCredential,
                "none": TokenCredential,
            }
            assert isinstance(store.credential, expected_credential_types[config_name])
        else:
            assert store.credential is None

    @pytest.mark.azuretest
    @pytest.mark.parametrize("config_name", ["sas", "none", "private"])
    def test_workspace_get_info(
        self, datastore_project_scopes: Mapping[str, Any], azure_cred: TokenCredential, config_name: str
    ):
        """``workspace_get_info`` returns a workspace with a name and an MLflow tracking URI."""
        project_scope = datastore_project_scopes[config_name]
        client = self._create_client(project_scope, azure_cred)

        workspace = client.workspace_get_info(project_scope["project_name"])

        assert workspace
        assert workspace.name
        assert workspace.ml_flow_tracking_uri