File: models.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (198 lines) | stat: -rw-r--r-- 7,226 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import datetime
import re
from base64 import b64decode
from typing import Any, Optional

import xmltodict

from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.core.utils import utcnow
from moto.iam.models import AccessKey, iam_backends
from moto.sts.utils import (
    DEFAULT_STS_SESSION_DURATION,
    random_assumed_role_id,
    random_session_token,
)
from moto.utilities.utils import ARN_PARTITION_REGEX, PARTITION_NAMES, get_partition


class Token(BaseModel):
    """Token returned by ``STSBackend.get_session_token`` and ``get_federation_token``."""

    def __init__(self, duration: int, name: Optional[str] = None):
        """
        :param duration: lifetime of the token in seconds, counted from creation.
        :param name: optional name attached to the token (set for federation tokens).
        """
        lifetime = datetime.timedelta(seconds=duration)
        self.expiration = utcnow() + lifetime
        self.name = name
        # No policy is attached at construction time.
        self.policy = None


class AssumedRole(BaseModel):
    """A role session: the assumed role plus the temporary credentials issued for it."""

    def __init__(
        self,
        account_id: str,
        region_name: str,
        access_key: AccessKey,
        role_session_name: str,
        role_arn: str,
        policy: Optional[str],
        duration: int,
        external_id: Optional[str],
    ):
        """
        :param account_id: account that owns the assumed role.
        :param region_name: region used to derive the ARN partition.
        :param access_key: temporary access key backing this session.
        :param role_session_name: caller-chosen name of the session.
        :param role_arn: ARN of the role being assumed.
        :param policy: optional session policy (``None`` when not supplied).
        :param duration: lifetime of the session in seconds, counted from creation.
        :param external_id: optional external id (``None`` when not supplied).
        """
        self.account_id = account_id
        self.region_name = region_name
        self.session_name = role_session_name
        self.role_arn = role_arn
        self.policy = policy
        now = utcnow()
        self.expiration = now + datetime.timedelta(seconds=duration)
        self.external_id = external_id
        self.access_key = access_key
        self.access_key_id = access_key.access_key_id
        self.secret_access_key = access_key.secret_access_key
        self.session_token = random_session_token()
        self.partition = get_partition(region_name)
        # Lazily generated stand-in role id, used only when the role cannot be
        # looked up in IAM. Cached so that user_id stays stable for the session.
        self._fallback_role_id: Optional[str] = None

    @property
    def user_id(self) -> str:
        """Return ``<role id>:<session name>``.

        The role id is taken from IAM when the role can be resolved by ARN.
        Otherwise a random ``AROA``-prefixed id is generated once and reused on
        every later access, so repeated reads return the same value.
        """
        iam_backend = iam_backends[self.account_id][self.partition]
        try:
            role_id = iam_backend.get_role_by_arn(arn=self.role_arn).id
        except Exception:
            # Best effort: the role does not have to exist in order to be assumed.
            if getattr(self, "_fallback_role_id", None) is None:
                self._fallback_role_id = "AROA" + random_assumed_role_id()
            role_id = self._fallback_role_id
        return role_id + ":" + self.session_name

    @property
    def arn(self) -> str:
        """Return the ``assumed-role`` ARN built from the role name and session name."""
        # The role name is the last path segment of the role ARN.
        role_name = self.role_arn.split("/")[-1]
        return f"arn:{self.partition}:sts::{self.account_id}:assumed-role/{role_name}/{self.session_name}"


class STSBackend(BaseBackend):
    """STS backend; instances are looked up as ``sts_backends[account_id][partition]``.

    Keeps the role sessions created through it so that a temporary access key
    id can later be mapped back to the session it belongs to.
    """

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        # Role sessions stored on this backend, in creation order.
        self.assumed_roles: list[AssumedRole] = []

    def get_session_token(self, duration: int) -> Token:
        """Return an unnamed Token that expires ``duration`` seconds from now."""
        return Token(duration=duration)

    def get_federation_token(self, name: Optional[str], duration: int) -> Token:
        """Return a Token carrying ``name`` that expires ``duration`` seconds from now."""
        return Token(duration=duration, name=name)

    def assume_role(
        self,
        region_name: str,
        role_session_name: str,
        role_arn: str,
        policy: str,
        duration: int,
        external_id: str,
    ) -> AssumedRole:
        """
        Assume an IAM Role. Note that the role does not need to exist. The ARN can point to another account, providing an opportunity to switch accounts.
        """
        # The owning account is derived from the role ARN, so it may differ
        # from the account this backend belongs to.
        account_id, access_key = self._create_access_key(role=role_arn)
        role = AssumedRole(
            account_id=account_id,
            region_name=region_name,
            access_key=access_key,
            role_session_name=role_session_name,
            role_arn=role_arn,
            policy=policy,
            duration=duration,
            external_id=external_id,
        )
        access_key.role_arn = role_arn
        # Store the session on the backend of the account that owns the role.
        account_backend = sts_backends[account_id][get_partition(region_name)]
        account_backend.assumed_roles.append(role)
        return role

    def get_assumed_role_from_access_key(
        self, access_key_id: str
    ) -> Optional[AssumedRole]:
        """Return the stored role session using ``access_key_id``, or ``None``."""
        return next(
            (
                session
                for session in self.assumed_roles
                if session.access_key_id == access_key_id
            ),
            None,
        )

    def assume_role_with_web_identity(self, **kwargs: Any) -> AssumedRole:
        """Assume a role; the keyword arguments are forwarded to ``assume_role``."""
        return self.assume_role(**kwargs)

    @staticmethod
    def _saml_attribute_text(attribute: dict[str, Any]) -> str:
        """Return the text of the (first) ``AttributeValue`` of a SAML attribute."""
        value = attribute["saml|AttributeValue"]
        # Repeated <AttributeValue> elements are parsed into a list; use the first.
        if isinstance(value, list):
            value = value[0]
        return value["#text"]

    def assume_role_with_saml(self, **kwargs: Any) -> AssumedRole:
        """Assume a role based on a base64-encoded SAML assertion.

        ``kwargs`` must contain ``principal_arn`` (discarded), ``saml_assertion``
        and the remaining ``AssumedRole`` arguments that are not filled in here
        (``role_arn``). The session name, the optional session duration and the
        role used to pick the target account are read from the assertion's
        attributes.

        :raises KeyError: if ``principal_arn``/``saml_assertion`` are missing or
            the assertion lacks the expected Response/Assertion/AttributeStatement
            structure.
        """
        del kwargs["principal_arn"]
        saml_assertion_encoded = kwargs.pop("saml_assertion")
        saml_assertion_decoded = b64decode(saml_assertion_encoded)

        # Map the SAML namespaces to short prefixes so keys read "saml|Attribute".
        namespaces = {
            "urn:oasis:names:tc:SAML:2.0:protocol": "samlp",
            "urn:oasis:names:tc:SAML:2.0:assertion": "saml",
        }
        saml_assertion = xmltodict.parse(
            saml_assertion_decoded.decode("utf-8"),
            force_cdata=True,
            process_namespaces=True,
            namespaces=namespaces,
            namespace_separator="|",
        )

        target_role = None
        saml_assertion_attributes = saml_assertion["samlp|Response"]["saml|Assertion"][
            "saml|AttributeStatement"
        ]["saml|Attribute"]
        # A statement with a single <Attribute> is parsed into a dict rather
        # than a list of dicts; normalise so the loop always sees dicts.
        if not isinstance(saml_assertion_attributes, list):
            saml_assertion_attributes = [saml_assertion_attributes]
        for attribute in saml_assertion_attributes:
            attribute_name = attribute["@Name"]
            if attribute_name == "https://aws.amazon.com/SAML/Attributes/RoleSessionName":
                kwargs["role_session_name"] = self._saml_attribute_text(attribute)
            elif (
                attribute_name
                == "https://aws.amazon.com/SAML/Attributes/SessionDuration"
            ):
                kwargs["duration"] = int(self._saml_attribute_text(attribute))
            elif attribute_name == "https://aws.amazon.com/SAML/Attributes/Role":
                # The value is a comma-separated pair; the first element is used.
                target_role = self._saml_attribute_text(attribute).split(",")[0]

        kwargs.setdefault("duration", DEFAULT_STS_SESSION_DURATION)

        # Without a Role attribute, fall back to the requested role ARN instead
        # of failing on a None value.
        account_id, access_key = self._create_access_key(
            role=target_role or kwargs.get("role_arn")
        )
        kwargs["account_id"] = account_id
        kwargs["region_name"] = self.region_name
        kwargs["access_key"] = access_key

        kwargs["external_id"] = None
        kwargs["policy"] = None
        role = AssumedRole(**kwargs)
        # NOTE(review): unlike assume_role, the session is stored on this
        # backend even when the role belongs to another account — confirm
        # whether that is intended.
        self.assumed_roles.append(role)
        return role

    def get_caller_identity(
        self, access_key_id: str, region: str
    ) -> tuple[str, str, str]:
        """Return ``(user_id, arn, account_id)`` for the given access key id.

        Lookup order: role sessions stored on this backend, then IAM users
        owning the key, then fixed default values (``region`` only affects the
        partition of the default ARN).
        """
        assumed_role = self.get_assumed_role_from_access_key(access_key_id)
        if assumed_role:
            return assumed_role.user_id, assumed_role.arn, assumed_role.account_id

        iam_backend = iam_backends[self.account_id][self.partition]
        user = iam_backend.get_user_from_access_key_id(access_key_id)
        if user:
            return user.id, user.arn, user.account_id

        # Default values in case the request does not use valid credentials generated by moto
        partition = get_partition(region)
        user_id = "AKIAIOSFODNN7EXAMPLE"
        arn = f"arn:{partition}:sts::{self.account_id}:user/moto"
        return user_id, arn, self.account_id

    def _create_access_key(self, role: Optional[str]) -> tuple[str, AccessKey]:
        """Create a temporary access key in the account that owns ``role``.

        :param role: role ARN; when it is empty/``None`` or carries no IAM
            account id, this backend's own account is used.
        :return: ``(account_id, access_key)``.
        """
        account_id_match = (
            re.search(ARN_PARTITION_REGEX + r":iam::([0-9]+).+", role)
            if role
            else None
        )
        if account_id_match:
            # Group 2 is the account id; group 1 is presumably the partition
            # captured by ARN_PARTITION_REGEX — verify against its definition.
            account_id = account_id_match.group(2)
        else:
            account_id = self.account_id
        iam_backend = iam_backends[account_id][self.partition]
        return account_id, iam_backend.create_temp_access_key()


# Registry of STSBackend instances, indexed as sts_backends[account_id][partition]
# (see STSBackend.assume_role).
# NOTE(review): use_boto3_regions=False with additional_regions=PARTITION_NAMES
# presumably keys the backends by partition name instead of AWS region — confirm
# against BackendDict.
sts_backends = BackendDict(
    STSBackend, "sts", use_boto3_regions=False, additional_regions=PARTITION_NAMES
)