1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
|
from typing import Any, Mapping
import pytest
import logging
from azure.core.credentials import AzureSasCredential, TokenCredential, AccessToken
from azure.ai.evaluation._azure._clients import LiteMLClient
@pytest.mark.usefixtures("model_config", "project_scope", "recording_injection", "recorded_test")
class TestLiteAzureManagementClient:
    """End to end tests for the lite Azure management client."""

    @staticmethod
    def _make_client(project_scope: Mapping[str, Any], azure_cred: TokenCredential) -> LiteMLClient:
        """Build a ``LiteMLClient`` for the subscription/resource group named in ``project_scope``.

        :param project_scope: Mapping providing at least ``subscription_id`` and
            ``resource_group_name``.
        :param azure_cred: Credential the client authenticates with.
        :return: A configured ``LiteMLClient``.
        """
        return LiteMLClient(
            subscription_id=project_scope["subscription_id"],
            resource_group=project_scope["resource_group_name"],
            credential=azure_cred,
            logger=logging.getLogger(__name__),
        )

    @pytest.mark.azuretest
    def test_get_credential(self, project_scope: Mapping[str, Any], azure_cred: TokenCredential) -> None:
        """``get_credential`` returns a ``TokenCredential``."""
        client = self._make_client(project_scope, azure_cred)

        credential = client.get_credential()

        assert isinstance(credential, TokenCredential)

    @pytest.mark.azuretest
    def test_get_token(self, project_scope: Mapping[str, Any], azure_cred: TokenCredential) -> None:
        """``get_token`` returns an ``AccessToken`` carrying a non-empty token string."""
        client = self._make_client(project_scope, azure_cred)

        token = client.get_token()

        # Separate assertions so a failure reports which property was wrong.
        assert isinstance(token, AccessToken)
        assert token.token, "expected a non-empty access token"

    # NOTE: decorator order is kept as-is; it determines the generated test IDs,
    # which recorded-test playback may be keyed on.
    @pytest.mark.azuretest
    @pytest.mark.parametrize("include_credentials", [False, True])
    @pytest.mark.parametrize("config_name", ["sas", "none"])
    def test_workspace_get_default_store(
        self,
        azure_cred: TokenCredential,
        datastore_project_scopes: Mapping[str, Any],
        config_name: str,
        include_credentials: bool,
    ) -> None:
        """``workspace_get_default_datastore`` returns a populated datastore.

        When ``include_credentials`` is set, the returned credential must be of the
        type expected for ``config_name``; otherwise it must be ``None``.
        """
        # Expected credential type per datastore configuration. "account_key" is not
        # currently parametrized above but is kept so it can be enabled without
        # touching the assertion logic.
        expected_credential_types: Mapping[str, type] = {
            "account_key": str,
            "sas": AzureSasCredential,
            "none": TokenCredential,
        }

        project_scope = datastore_project_scopes[config_name]
        client = self._make_client(project_scope, azure_cred)

        store = client.workspace_get_default_datastore(
            workspace_name=project_scope["project_name"], include_credentials=include_credentials
        )

        assert store
        assert store.name
        assert store.account_name
        assert store.endpoint
        assert store.container_name
        if include_credentials:
            assert isinstance(store.credential, expected_credential_types[config_name])
        else:
            assert store.credential is None

    @pytest.mark.azuretest
    @pytest.mark.parametrize("config_name", ["sas", "none", "private"])
    def test_workspace_get_info(
        self, datastore_project_scopes: Mapping[str, Any], azure_cred: TokenCredential, config_name: str
    ) -> None:
        """``workspace_get_info`` returns a workspace with a name and an MLflow tracking URI."""
        project_scope = datastore_project_scopes[config_name]
        client = self._make_client(project_scope, azure_cred)

        workspace = client.workspace_get_info(project_scope["project_name"])

        assert workspace
        assert workspace.name
        assert workspace.ml_flow_tracking_uri
|