File: _auth.py

package info (click to toggle)
python-azure 20250603%2Bgit-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (77 lines) | stat: -rw-r--r-- 2,959 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------


from base64 import b64decode, b64encode
from hashlib import sha256
from hmac import HMAC
from time import time
from urllib.parse import quote_plus, urlencode

from azure.core.credentials import AzureSasCredential
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import SansIOHTTPPolicy


def generate_sas_token(audience: str, policy: str, key: str, expiry: int = 3600) -> str:
    """
    Generate a SAS token according to the given audience, policy, key and expiry.

    The URL-encoded audience and the absolute expiry timestamp, joined by a
    newline, are signed with HMAC-SHA256 using the base64-decoded key.

    :param str audience: The audience / endpoint to create the SAS token for
    :param str policy: The policy this token represents; when empty, the
        ``skn`` field is left out of the token
    :param str key: The base64-encoded key used to sign this token
    :param int expiry: Token lifetime in seconds, counted from the current time
        (defaults to 3600, i.e. one hour)
    :returns: SAS token as a string, prefixed with ``SharedAccessSignature ``
    :rtype: str
    :raises binascii.Error: If ``key`` is incorrectly padded base64
    """

    encoded_uri = quote_plus(audience)

    # time() returns seconds since the epoch, so ``expiry`` is a number of
    # seconds and ``ttl`` is the absolute expiry as a Unix timestamp.
    ttl = int(time() + expiry)
    sign_key = f"{encoded_uri}\n{ttl}"
    signature = b64encode(
        HMAC(b64decode(key), sign_key.encode("utf-8"), sha256).digest()
    ).decode("ascii")
    # urlencode() percent-encodes every value itself, so the raw (not the
    # pre-encoded) audience is passed here to avoid double encoding.
    result = {"sr": audience, "sig": signature, "se": str(ttl)}
    if policy:
        result["skn"] = policy
    return "SharedAccessSignature " + urlencode(result)


class SharedKeyCredentialPolicy(SansIOHTTPPolicy):
    """Adds a SAS token generated from a shared key as the Authorization header.

    A new token is generated for every request via ``generate_sas_token``.

    :param str endpoint: The audience / endpoint the SAS token is created for
    :param str policy_name: The name of the policy the token represents
    :param str key: The key used to sign the token
    """

    def __init__(self, endpoint: str, policy_name: str, key: str) -> None:
        self.endpoint = endpoint
        self.policy_name = policy_name
        self.key = key
        super().__init__()

    def _add_authorization_header(self, request: PipelineRequest) -> None:
        """Generate a SAS token and set it as the request's Authorization header.

        Any error raised while generating the token propagates unchanged.

        :param request: The pipeline request whose HTTP headers are updated
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # TODO - Wrap error as a signing error?
        request.http_request.headers["Authorization"] = generate_sas_token(
            audience=self.endpoint, policy=self.policy_name, key=self.key
        )

    def on_request(self, request: PipelineRequest) -> None:
        """Sign the outgoing request before it is sent.

        :param request: The pipeline request to authorize
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        self._add_authorization_header(request=request)


class SasCredentialPolicy(SansIOHTTPPolicy):
    """Pipeline policy that authorizes requests with a SAS credential.

    :param credential: The credential whose signature is sent as the
        Authorization header of each request.
    :type credential: ~azure.core.credentials.AzureSasCredential
    """

    def __init__(
        self, credential: AzureSasCredential, **kwargs
    ):  # pylint: disable=unused-argument
        # Extra keyword arguments are accepted but not used.
        super().__init__()
        self._credential = credential

    def on_request(self, request: PipelineRequest) -> None:
        """Copy the credential's signature into the Authorization header.

        :param request: The pipeline request to authorize
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # The signature is read on every request, not cached on the policy.
        signature = self._credential.signature
        request.http_request.headers["Authorization"] = signature