1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
|
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Demonstrates custom credential implementations using existing access tokens and an MSAL client"""
import time
from typing import TYPE_CHECKING
from azure.core.credentials import AccessToken
from azure.identity import AuthenticationRequiredError, AzureAuthorityHosts
import msal
if TYPE_CHECKING:
from typing import Any, Union
class StaticTokenCredential(object):
    """Credential that hands out a single access token acquired ahead of time.

    The wrapped token is never refreshed. An access token is valid only for certain resources and eventually
    expires, so the application using this credential must ensure the token is still valid and contains every
    claim required by the service clients the credential is given to.
    """

    def __init__(self, access_token):
        # type: (Union[str, AccessToken]) -> None
        """Store the token that get_token will return.

        :param access_token: a raw token string, or an AccessToken which is stored exactly as given
        """
        if not isinstance(access_token, AccessToken):
            # A raw value carries no expiry information. Marking it as already expired (expires_on=0) causes
            # Azure SDK clients to call get_token every time they need a token.
            access_token = AccessToken(token=access_token, expires_on=0)
        self._token = access_token

    def get_token(self, *scopes, **kwargs):
        # type: (*str, **Any) -> AccessToken
        """Return the stored token; the requested scopes and keyword arguments are ignored.

        get_token is the only method a credential must implement.
        """
        return self._token
class MsalTokenCredential(object):
    """Uses an MSAL client directly to obtain access tokens with an interactive flow."""

    def __init__(self, tenant_id, client_id):
        # type: (str, str) -> None
        """Create the MSAL public client application used to request tokens.

        :param tenant_id: tenant appended to the Azure public cloud authority host to form the authority URL
        :param client_id: client ID passed to the MSAL application
        """
        authority = "https://{}/{}".format(AzureAuthorityHosts.AZURE_PUBLIC_CLOUD, tenant_id)
        self._app = msal.PublicClientApplication(client_id=client_id, authority=authority)

    def get_token(self, *scopes, **kwargs):
        # type: (*str, **Any) -> AccessToken
        """Acquire a token for the given scopes through MSAL's interactive flow.

        get_token is the only method a credential must implement.

        :param scopes: scopes to request; passed to MSAL as a list
        :param kwargs: forwarded unchanged to ``acquire_token_interactive``
        :returns: an AccessToken built from the "access_token" and "expires_in" values of the MSAL result
        :raises AuthenticationRequiredError: when the MSAL result does not contain a usable token
        """
        # Read the clock before the request: the computed expiry can then only be earlier than the real one.
        now = int(time.time())
        result = self._app.acquire_token_interactive(list(scopes), **kwargs)
        try:
            return AccessToken(result["access_token"], now + int(result["expires_in"]))
        except (KeyError, TypeError, ValueError):
            # Handle only what a missing or malformed result can raise, so KeyboardInterrupt and SystemExit
            # are not converted into an authentication error.
            print("\nFailed to get a valid access token")
            details = None
            if isinstance(result, dict):
                # MSAL describes failures with "error" / "error_description" entries in the result
                details = result.get("error_description") or result.get("error")
            # message=None makes the exception fall back to its default message
            raise AuthenticationRequiredError(scopes, message=details)
|