File: policy.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (186 lines) | stat: -rw-r--r-- 6,699 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import json
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
)

from moto.awslambda.exceptions import (
    GenericResourcNotFound,
    PreconditionFailedException,
    UnknownPolicyException,
)
from moto.moto_api._internal import mock_random

if TYPE_CHECKING:
    from .models import LambdaFunction, LayerVersion

# Generic type used by the formatter helpers that pass a value through unchanged
# or wrap it without inspecting it.
TYPE_IDENTITY = TypeVar("TYPE_IDENTITY")


class Policy:
    """Resource-based policy attached to a Lambda function or layer version.

    ``statements`` holds IAM-style statement dicts built by
    :meth:`decode_policy` from AddPermission request bodies. ``revision`` is a
    random UUID string that is replaced whenever a statement is added or
    removed. Callers may pass it back as ``RevisionId`` to make a change
    conditional on the policy not having been modified in the meantime.
    """

    def __init__(self, parent: Union["LambdaFunction", "LayerVersion"]):
        self.revision = str(mock_random.uuid4())
        self.statements: list[dict[str, Any]] = []
        # Function or layer version whose ARN is used as the default Resource.
        self.parent = parent

    def wire_format(self) -> str:
        """Return the policy serialized as a JSON string.

        The inner policy document is itself JSON-encoded, so the result has the
        shape ``{"Policy": "<json string>", "RevisionId": "..."}``.

        Raises:
            GenericResourcNotFound: if the policy has no statements.
        """
        payload = self.get_policy()
        payload["Policy"] = json.dumps(payload["Policy"])
        return json.dumps(payload)

    def get_policy(self) -> dict[str, Any]:
        """Return the policy document together with the current revision id.

        Raises:
            GenericResourcNotFound: if the policy has no statements.
        """
        if not self.statements:
            raise GenericResourcNotFound()
        return {
            "Policy": {
                "Version": "2012-10-17",
                "Id": "default",
                "Statement": self.statements,
            },
            "RevisionId": self.revision,
        }

    def _check_revision(self, revision: Optional[str]) -> None:
        """Raise if ``revision`` is given and differs from the current revision.

        An empty or missing (``None``/JSON ``null``) revision means the caller
        did not request a conditional update, so no check is made.

        Raises:
            PreconditionFailedException: on a revision mismatch.
        """
        if revision and self.revision != revision:
            raise PreconditionFailedException(
                "The RevisionId provided does not match the latest RevisionId"
                " for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
                " the latest RevisionId for your resource."
            )

    @staticmethod
    def _qualify_resource(resource: str, qualifier: Optional[str]) -> str:
        """Drop a trailing ``:$LATEST`` from ``resource`` and append ``:<qualifier>`` if given."""
        latest_suffix = ":$LATEST"
        if resource.endswith(latest_suffix):
            resource = resource[: -len(latest_suffix)]
        if qualifier:
            resource = f"{resource}:{qualifier}"
        return resource

    def add_statement(
        self, raw: str, qualifier: Optional[str] = None
    ) -> tuple[Any, str]:
        """Add the statement described by a raw AddPermission JSON request body.

        Args:
            raw: JSON request body; see :meth:`decode_policy` for how its keys
                are turned into a policy statement.
            qualifier: optional version/alias. When given, it is appended to
                the statement's ``Resource`` ARN in place of ``$LATEST``.

        Returns:
            The stored statement dict and the new revision id.

        Raises:
            PreconditionFailedException: if the body carries a ``RevisionId``
                that does not match the current revision.
        """
        policy = json.loads(raw, object_hook=self.decode_policy)
        self._check_revision(policy.revision)

        statement = policy.statements[0]
        # decode_policy always sets Resource (defaulting to "<arn>:$LATEST").
        statement["Resource"] = self._qualify_resource(
            statement["Resource"], qualifier
        )
        self.statements.append(statement)
        self.revision = str(mock_random.uuid4())
        return statement, self.revision

    def del_statement(self, sid: str, revision: str = "") -> None:
        """Remove the statement whose ``Sid`` equals ``sid``.

        Args:
            sid: statement id to remove.
            revision: optional expected current revision. If non-empty and
                it does not match, nothing is removed.

        Raises:
            PreconditionFailedException: on a revision mismatch.
            UnknownPolicyException: if no statement has the given ``Sid``.
        """
        self._check_revision(revision)
        for statement in self.statements:
            if "Sid" in statement and statement["Sid"] == sid:
                # Safe to mutate the list here: the loop stops right after.
                self.statements.remove(statement)
                break
        else:
            raise UnknownPolicyException()
        # The policy changed, so issue a new revision. Conditional requests
        # made with the old RevisionId must then fail, as they do after
        # add_statement.
        self.revision = str(mock_random.uuid4())

    def decode_policy(self, obj: dict[str, Any]) -> "Policy":
        """Convert an AddPermission request object into a single-statement Policy.

        Used as ``json.loads``'s ``object_hook``. Note that ``object_hook`` is
        invoked for every JSON object, so a nested object in the body would
        itself be converted. The AddPermission parameters are expected to be
        scalars.

        The transformation, applied to ``obj`` in place:

        * ``Effect`` defaults to ``"Allow"``.
        * ``Resource`` defaults to the parent's ARN plus ``":$LATEST"``.
        * ``StatementId`` defaults to a random UUID and is renamed to ``Sid``.
        * ``Principal`` strings are wrapped by :meth:`principal_formatter`.
        * ``SourceArn``, ``SourceAccount`` and ``PrincipalOrgID`` become IAM
          condition blocks, merged under ``Condition``.
        * ``RevisionId`` and ``EventSourceToken`` are dropped.
        * Any other keys are left unchanged.

        The request's ``RevisionId`` (or ``""``) is kept on the returned
        Policy's ``revision`` attribute.
        """
        # Circumvent circular import
        from moto.awslambda.models import LayerVersion

        policy = Policy(self.parent)
        policy.revision = obj.get("RevisionId", "")
        # Layer versions expose their ARN as ``arn``; functions use ``function_arn``.
        if isinstance(self.parent, LayerVersion):
            resource_arn = self.parent.arn
        else:
            resource_arn = self.parent.function_arn

        # set some default values if these keys are not set
        self.ensure_set(obj, "Effect", "Allow")
        self.ensure_set(obj, "Resource", resource_arn + ":$LATEST")
        self.ensure_set(obj, "StatementId", str(mock_random.uuid4()))

        # transform field names and values
        self.transform_property(obj, "StatementId", "Sid", self.nop_formatter)
        self.transform_property(obj, "Principal", "Principal", self.principal_formatter)

        self.transform_property(
            obj, "SourceArn", "SourceArn", self.source_arn_formatter
        )
        self.transform_property(
            obj, "SourceAccount", "SourceAccount", self.source_account_formatter
        )
        self.transform_property(
            obj, "PrincipalOrgID", "Condition", self.principal_org_id_formatter
        )

        # remove RevisionId and EventSourceToken if they are set
        self.remove_if_set(obj, ["RevisionId", "EventSourceToken"])

        # merge conditional statements into a single map under the Condition key
        self.condition_merge(obj)

        policy.statements.append(obj)

        return policy

    def nop_formatter(self, obj: TYPE_IDENTITY) -> TYPE_IDENTITY:
        """Return ``obj`` unchanged (formatter for pure renames)."""
        return obj

    def ensure_set(self, obj: dict[str, Any], key: str, value: Any) -> None:
        """Set ``obj[key] = value`` only if ``key`` is not already present."""
        if key not in obj:
            obj[key] = value

    def principal_formatter(self, obj: Any) -> Any:
        """Wrap a principal string in the IAM ``Principal`` shape.

        Strings ending in ``.amazonaws.com`` become ``{"Service": obj}``, and
        strings ending in ``:root`` become ``{"AWS": obj}``. Any other value
        is returned unchanged.
        """
        if isinstance(obj, str):
            if obj.endswith(".amazonaws.com"):
                return {"Service": obj}
            if obj.endswith(":root"):
                return {"AWS": obj}
        return obj

    def source_account_formatter(
        self, obj: TYPE_IDENTITY
    ) -> dict[str, dict[str, TYPE_IDENTITY]]:
        """Build the ``StringEquals`` condition for ``AWS:SourceAccount``."""
        return {"StringEquals": {"AWS:SourceAccount": obj}}

    def source_arn_formatter(
        self, obj: TYPE_IDENTITY
    ) -> dict[str, dict[str, TYPE_IDENTITY]]:
        """Build the ``ArnLike`` condition for ``AWS:SourceArn``."""
        return {"ArnLike": {"AWS:SourceArn": obj}}

    def principal_org_id_formatter(
        self, obj: TYPE_IDENTITY
    ) -> dict[str, dict[str, TYPE_IDENTITY]]:
        """Build the ``StringEquals`` condition for ``aws:PrincipalOrgID``."""
        return {"StringEquals": {"aws:PrincipalOrgID": obj}}

    def transform_property(
        self,
        obj: dict[str, Any],
        old_name: str,
        new_name: str,
        formatter: Callable[..., Any],
    ) -> None:
        """Replace ``obj[old_name]`` with ``formatter(value)`` stored under ``new_name``.

        Does nothing if ``old_name`` is absent. Any existing ``new_name`` entry
        is overwritten, and ``old_name`` is removed when the names differ.
        """
        if old_name in obj:
            obj[new_name] = formatter(obj[old_name])
            if new_name != old_name:
                del obj[old_name]

    def remove_if_set(self, obj: dict[str, Any], keys: list[str]) -> None:
        """Delete each of ``keys`` from ``obj`` if present."""
        for key in keys:
            if key in obj:
                del obj[key]

    def condition_merge(self, obj: dict[str, Any]) -> None:
        """Fold formatted ``SourceArn``/``SourceAccount`` conditions into ``obj["Condition"]``.

        Conditions are merged per operator (``StringEquals``, ``ArnLike``...).
        As a result, a ``StringEquals`` clause from SourceAccount is combined
        with one already present (e.g. from PrincipalOrgID) instead of
        replacing it.
        """
        for key in ("SourceArn", "SourceAccount"):
            if key not in obj:
                continue
            condition = obj.setdefault("Condition", {})
            for operator, clauses in obj.pop(key).items():
                condition.setdefault(operator, {}).update(clauses)