File: credentials.py

package info (click to toggle)
libmongocrypt 1.17.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,572 kB
  • sloc: ansic: 70,067; python: 4,547; cpp: 615; sh: 460; makefile: 44; awk: 8
file content (156 lines) | stat: -rw-r--r-- 5,521 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import namedtuple
from datetime import datetime, timedelta, timezone

try:
    from pymongo_auth_aws.auth import aws_temp_credentials

    _HAVE_AUTH_AWS = True
except ImportError:
    _HAVE_AUTH_AWS = False

import httpx

from pymongocrypt.errors import MongoCryptError

_azure_creds = namedtuple("_azure_creds", ["access_token", "expires_utc"])
_azure_creds_cache = None


async def _get_gcp_credentials():
    """Get on-demand GCP credentials"""
    metadata_host = os.getenv("GCE_METADATA_HOST") or "metadata.google.internal"
    url = (
        "http://%s/computeMetadata/v1/instance/service-accounts/default/token"
        % metadata_host
    )

    headers = {"Metadata-Flavor": "Google"}
    client = httpx.AsyncClient()
    try:
        response = await client.get(url, headers=headers)
    except Exception as e:
        msg = "unable to retrieve GCP credentials: %s" % e
        raise MongoCryptError(msg) from e
    finally:
        await client.aclose()

    if response.status_code != 200:
        msg = f"Unable to retrieve GCP credentials: expected StatusCode 200, got StatusCode: {response.status_code}. Response body:\n{response.content}"
        raise MongoCryptError(msg)
    try:
        data = response.json()
    except Exception as e:
        raise MongoCryptError(
            f"unable to retrieve GCP credentials: error reading response body\n{response.content}"
        ) from e

    if not data.get("access_token"):
        msg = (
            "unable to retrieve GCP credentials: got unexpected empty accessToken from GCP Metadata Server. Response body: %s"
            % response.content
        )
        raise MongoCryptError(msg)

    return {"accessToken": data["access_token"]}


async def _get_azure_credentials():
    """Get on-demand Azure credentials"""
    global _azure_creds_cache
    # Credentials are considered expired when: Expiration - now < 1 mins.
    creds = _azure_creds_cache
    if creds:
        if creds.expires_utc - datetime.now(tz=timezone.utc) < timedelta(seconds=60):
            _azure_creds_cache = None
        else:
            return {"accessToken": creds.access_token}

    url = "http://169.254.169.254/metadata/identity/oauth2/token"
    url += "?api-version=2018-02-01"
    url += "&resource=https://vault.azure.net"
    headers = {"Metadata": "true", "Accept": "application/json"}
    client = httpx.AsyncClient()
    try:
        response = await client.get(url, headers=headers)
    except Exception as e:
        msg = "Failed to acquire IMDS access token: %s" % e
        raise MongoCryptError(msg) from e
    finally:
        await client.aclose()

    if response.status_code != 200:
        msg = "Failed to acquire IMDS access token."
        raise MongoCryptError(msg)
    try:
        data = response.json()
    except Exception as e:
        raise MongoCryptError("Azure IMDS response must be in JSON format.") from e

    for key in ["access_token", "expires_in"]:
        if not data.get(key):
            msg = "Azure IMDS response must contain %s, but was %s."
            msg = msg % (key, response.content)
            raise MongoCryptError(msg)

    try:
        expires_in = int(data["expires_in"])
    except ValueError as e:
        raise MongoCryptError(
            'Azure IMDS response must contain "expires_in" integer, but was %s.'
            % response.content
        ) from e

    expires_utc = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in)
    _azure_creds_cache = _azure_creds(data["access_token"], expires_utc)
    return {"accessToken": data["access_token"]}


async def _ask_for_kms_credentials(kms_providers):
    """Get on-demand kms credentials.

    This is a separate function so it can be overridden in unit tests."""
    global _azure_creds_cache
    on_demand_aws = "aws" in kms_providers and not len(kms_providers["aws"])
    on_demand_gcp = "gcp" in kms_providers and not len(kms_providers["gcp"])
    on_demand_azure = "azure" in kms_providers and not len(kms_providers["azure"])

    if not any([on_demand_aws, on_demand_gcp, on_demand_azure]):
        return {}
    creds = {}
    if on_demand_aws:
        if not _HAVE_AUTH_AWS:
            raise RuntimeError(
                "On-demand AWS credentials require pymongo-auth-aws: "
                "install with: python -m pip install 'pymongo[aws]'"
            )
        aws_creds = aws_temp_credentials()
        creds_dict = {
            "accessKeyId": aws_creds.username,
            "secretAccessKey": aws_creds.password,
        }
        if aws_creds.token:
            creds_dict["sessionToken"] = aws_creds.token
        creds["aws"] = creds_dict
    if on_demand_gcp:
        creds["gcp"] = await _get_gcp_credentials()
    if on_demand_azure:
        try:
            creds["azure"] = await _get_azure_credentials()
        except Exception:
            _azure_creds_cache = None
            raise
    return creds