File: _checks.py

package info (click to toggle)
python-oslo.policy 5.0.0-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,032 kB
  • sloc: python: 6,880; makefile: 23; sh: 20
file content (436 lines) | stat: -rw-r--r-- 12,836 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
#
# Copyright (c) 2015 OpenStack Foundation.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import abc
import ast
from collections.abc import Callable, Mapping, MutableMapping, Sequence
import inspect
from typing import Any, TypeAlias, TypeVar, TYPE_CHECKING, overload

from oslo_context import context
import stevedore
from typing_extensions import Self

if TYPE_CHECKING:
    from typing_extensions import Self

    from oslo_policy.policy import Enforcer

# Credentials: either an oslo.context RequestContext or a plain mutable
# mapping of attributes describing the caller.
CredsT: TypeAlias = context.RequestContext | MutableMapping[str, Any]
# Target: a (read-only typed) mapping describing the object being acted on.
TargetT: TypeAlias = Mapping[str, Any]

# Check classes recorded by register(), keyed by the name they were
# registered under. GenericCheck registers under ``None`` (see
# ``@register(None)``); presumably the parser uses that entry as the
# fallback for unregistered kinds -- TODO confirm against the parser.
registered_checks: dict[str | None, 'type[Check]'] = {}
# Cache of stevedore-discovered check plugins; ``None`` until
# get_extensions() has run once.
extension_checks: dict[str, 'Callable[..., Check]'] | None = None


def get_extensions() -> dict[str, 'Callable[..., Check]']:
    """Return the check plugins published under ``oslo.policy.rule_checks``.

    The stevedore entry points are scanned only on the first call. The
    resulting name-to-plugin mapping is stored in the module-level
    ``extension_checks`` and returned as-is on every later call. Plugins
    are loaded without being invoked (``invoke_on_load=False``).

    :returns: a dict mapping each extension name to its plugin object.
    """
    global extension_checks
    if extension_checks is not None:
        # Already discovered; reuse the cached mapping.
        return extension_checks

    manager: stevedore.ExtensionManager[Check] = stevedore.ExtensionManager(
        'oslo.policy.rule_checks', invoke_on_load=False
    )
    discovered: dict[str, Callable[..., Check]] = {}
    for ext in manager:
        discovered[ext.name] = ext.plugin
    extension_checks = discovered
    return extension_checks


def _check(
    rule: 'BaseCheck',
    target: TargetT,
    creds: MutableMapping[str, Any],
    enforcer: 'Enforcer',
    current_rule: str | None,
) -> bool:
    """Evaluate the rule.

    This private method is meant to be used by the enforcer to call
    the rule. It can also be used by built-in checks that have nested
    rules.

    We use a private function because it makes it easier to change the
    API without having an impact on subclasses not defined within the
    oslo.policy library.

    We don't put this logic in Enforcer.enforce() and invoke that
    method recursively because that changes the BaseCheck API to
    require that the enforcer argument to __call__() be a valid
    Enforcer instance (as evidenced by all of the breaking unit
    tests).

    We don't put this in a private method of BaseCheck because that
    propagates the problem of extending the list of arguments to
    __call__() if subclasses change the implementation of the
    function.

    :param rule: A check object.
    :param target: Attributes of the object of the operation.
    :param creds: Attributes of the user performing the operation.
    :param enforcer: The Enforcer being used.
    :param current_rule: The name of the policy being checked.
    """
    # Evaluate the rule so we can check if the rule argument must be included
    # or not
    argspec = inspect.getfullargspec(rule.__call__)
    if len(argspec.args) > 4:
        return rule(target, creds, enforcer, current_rule)
    else:  # legacy code path
        return rule(target, creds, enforcer)


class BaseCheck(metaclass=abc.ABCMeta):
    """Abstract root of every policy check.

    Concrete checks must implement ``__str__``, which renders the check
    tree, and ``__call__``, which evaluates it.
    """

    # Optional scope types attached to the check; part of equality below.
    scope_types: list[str] | None = None

    def __eq__(self, other: Any) -> bool:
        """Equal when ``other`` has the exact same type and scope types."""
        if type(other) is not type(self):
            return False
        return self.scope_types == other.scope_types

    @abc.abstractmethod
    def __str__(self) -> str:
        """Render the check tree rooted at this node as a string."""

    @abc.abstractmethod
    def __call__(
        self,
        target: TargetT,
        creds: MutableMapping[str, Any],
        enforcer: 'Enforcer',
        current_rule: str | None = None,
    ) -> bool:
        """Evaluate the check.

        Implementations return ``False`` to deny access and a truthy value
        (not necessarily ``True``) to grant it.
        """


class FalseCheck(BaseCheck):
    """Deny-all check: evaluation always yields ``False``."""

    def __str__(self) -> str:
        """Render this check as the deny-all token ``!``."""
        return '!'

    def __call__(
        self,
        target: TargetT,
        creds: MutableMapping[str, Any],
        enforcer: 'Enforcer',
        current_rule: str | None = None,
    ) -> bool:
        """Reject the request; no argument is consulted."""
        return False


class TrueCheck(BaseCheck):
    """Allow-all check: evaluation always yields ``True``."""

    def __str__(self) -> str:
        """Render this check as the allow-all token ``@``."""
        return '@'

    def __call__(
        self,
        target: TargetT,
        creds: MutableMapping[str, Any],
        enforcer: 'Enforcer',
        current_rule: str | None = None,
    ) -> bool:
        """Accept the request; no argument is consulted."""
        return True


class Check(BaseCheck):
    """Base for checks written as ``kind:match``.

    Subclasses supply ``__call__``; this class only stores the two halves
    of the textual form and provides equality and rendering.
    """

    def __init__(self, kind: str, match: str) -> None:
        # The two halves of the ``kind:match`` text rendered by __str__.
        self.kind = kind
        self.match = match

    def __eq__(self, other: Any) -> bool:
        """Equal when of the exact same type with the same kind and match."""
        if type(self) is not type(other):
            return False
        return self.kind == other.kind and self.match == other.match

    def __str__(self) -> str:
        """Render this check back as ``kind:match``."""
        return '{}:{}'.format(self.kind, self.match)


class NotCheck(BaseCheck):
    """Negate the result of a single wrapped check."""

    def __init__(self, rule: BaseCheck) -> None:
        # Any BaseCheck can be wrapped, including AndCheck/OrCheck
        # groupings. Evaluation goes through _check(), which accepts every
        # BaseCheck, so the annotation is not limited to leaf checks.
        self.rule = rule

    def __eq__(self, other: Any) -> bool:
        """Equal when of the exact same type and wrapping equal rules."""
        return type(self) is type(other) and self.rule == other.rule

    def __str__(self) -> str:
        """Return a string representation of this check."""
        return f'not {self.rule}'

    def __call__(
        self,
        target: TargetT,
        creds: MutableMapping[str, Any],
        enforcer: 'Enforcer',
        current_rule: str | None = None,
    ) -> bool:
        """Check the policy.

        Returns the logical inverse of the wrapped check.
        """
        return not _check(self.rule, target, creds, enforcer, current_rule)


class AndCheck(BaseCheck):
    """Conjunction of checks: passes only when every child check passes."""

    def __init__(self, rules: Sequence[BaseCheck]) -> None:
        # Copy into a list so add_check() can append without mutating the
        # caller's sequence.
        self.rules = list(rules)

    def __eq__(self, other: Any) -> bool:
        """Equal when of the exact same type with equal child rules."""
        if type(self) is not type(other):
            return False
        return self.rules == other.rules

    def __str__(self) -> str:
        """Render as ``(a and b and ...)``."""
        return '(' + ' and '.join(map(str, self.rules)) + ')'

    def __call__(
        self,
        target: TargetT,
        creds: MutableMapping[str, Any],
        enforcer: 'Enforcer',
        current_rule: str | None = None,
    ) -> bool:
        """Check the policy.

        Children are evaluated in order. Evaluation stops at the first child
        that rejects, and True is returned only when all of them accept.
        """
        return all(
            _check(child, target, creds, enforcer, current_rule)
            for child in self.rules
        )

    def add_check(self, rule: BaseCheck) -> 'Self':
        """Append ``rule`` to the checks that must all pass.

        :returns: self
        :rtype: :class:`.AndCheck`
        """
        self.rules.append(rule)
        return self


class OrCheck(BaseCheck):
    """Disjunction of checks: passes when at least one child check passes."""

    def __init__(self, rules: Sequence[BaseCheck]) -> None:
        # Copy into a list so add_check()/pop_check() can mutate it without
        # touching the caller's sequence.
        self.rules = list(rules)

    def __eq__(self, other: Any) -> bool:
        """Equal when of the exact same type with equal child rules."""
        return type(self) is type(other) and self.rules == other.rules

    def __str__(self) -> str:
        """Return a string representation of this check."""
        return '({})'.format(' or '.join(str(r) for r in self.rules))

    def __call__(
        self,
        target: TargetT,
        creds: MutableMapping[str, Any],
        enforcer: 'Enforcer',
        current_rule: str | None = None,
    ) -> bool:
        """Check the policy.

        Requires that at least one rule accept in order to return True.
        Rules are evaluated in order, and evaluation stops at the first rule
        that accepts.
        """
        return any(
            _check(rule, target, creds, enforcer, current_rule)
            for rule in self.rules
        )

    def add_check(self, rule: BaseCheck) -> Self:
        """Adds rule to be tested.

        Allows addition of another rule to the list of rules that will
        be tested.  Returns the OrCheck object for convenience.

        :returns: self
        :rtype: :class:`.OrCheck`
        """
        self.rules.append(rule)
        return self

    def pop_check(self) -> tuple[Self, BaseCheck]:
        """Pop the last check from the list and return it with this OrCheck.

        :returns: self, the popped check
        :rtype: tuple of :class:`.OrCheck` and :class:`.BaseCheck`
        :raises IndexError: if there are no rules left to pop.
        """
        check = self.rules.pop()
        return self, check


# Type variable for the check classes handled by register(). Bounding it to
# type[Check] lets register()'s inner decorator be typed as returning the
# very class it received.
F = TypeVar('F', bound=type[Check])


@overload
def register(name: str | None, func: None = None) -> Callable[[F], F]: ...


@overload
def register(name: str | None, func: F) -> F: ...


def register(name: str | None, func: F | None = None) -> F | Callable[[F], F]:
    """Record a check class in ``registered_checks`` under ``name``.

    Two call forms are supported:

    - Direct call: ``register('spam', TestCheck)`` stores ``TestCheck`` and
      returns it.
    - Decorator: ``@register('spam')`` returns a decorator that stores the
      decorated class and hands it back unchanged.

    Registering again under the same ``name`` replaces the earlier entry.
    """

    def _record(check_cls: F) -> F:
        registered_checks[name] = check_cls
        return check_cls

    if func is None:
        # Decorator form: defer recording until the class is supplied.
        return _record
    return _record(func)


@register('rule')
class RuleCheck(Check):
    """Delegate evaluation to another named rule held by the enforcer.

    ``self.match`` is the name looked up in ``enforcer.rules``.
    """

    def __call__(
        self,
        target: TargetT,
        creds: MutableMapping[str, Any],
        enforcer: 'Enforcer',
        current_rule: str | None = None,
    ) -> bool:
        """Evaluate the referenced rule; an unknown name denies access."""
        try:
            referenced = enforcer.rules[self.match]
            return _check(
                rule=referenced,
                target=target,
                creds=creds,
                enforcer=enforcer,
                current_rule=current_rule,
            )
        except KeyError:
            # The rule name is unknown, or a KeyError escaped while the
            # referenced rule was evaluated: fail closed.
            return False


@register('role')
class RoleCheck(Check):
    """Grant access when ``creds['roles']`` contains the expected role.

    ``self.match`` is %-interpolated with the target, so it may contain
    ``%(key)s`` placeholders. Role names are compared case-insensitively.
    """

    def __call__(
        self,
        target: TargetT,
        creds: MutableMapping[str, Any],
        enforcer: 'Enforcer',
        current_rule: str | None = None,
    ) -> bool:
        """Return True if the interpolated role is among the caller's roles."""
        try:
            wanted = self.match % target
        except KeyError:
            # A placeholder names a key the target does not have: deny.
            return False

        if 'roles' not in creds:
            return False

        held = [role.lower() for role in creds['roles']]
        return wanted.lower() in held


@register(None)
class GenericCheck(Check):
    """Check an individual match.

    Matches look like:

        - tenant:%(tenant_id)s
        - role:compute:admin
        - True:%(user.enabled)s
        - 'Member':%(role.name)s

    The text after the first colon (``self.match``) is %-interpolated with
    the target. The text before it (``self.kind``) is handled in one of two
    ways:

    - If it parses as a Python literal, the literal's string form is
      compared with the interpolated match.
    - Otherwise it is treated as a dotted path looked up in the credentials.
    """

    @classmethod
    def _find_in_dict(
        cls,
        test_value: Any,
        path_segments: list[str],
        match: str,
    ) -> bool:
        """Searches for a match in the dictionary.

        test_value is a reference inside the dictionary. Since the process is
        recursive, each call to _find_in_dict will be one level deeper.

        path_segments is the segments of the path to search.  The recursion
        ends when there are no more segments of path.

        When specifying a value inside a list, each element of the list is
        checked for a match. If the value is found within any of the sub lists
        the check succeeds; The check only fails if the entry is not in any of
        the sublists.

        A path that does not exist returns False. This covers a missing key,
        and also a segment that lands on a value that cannot be indexed by a
        string key.
        """
        if not path_segments:
            return match == str(test_value)
        key, path_segments = path_segments[0], path_segments[1:]
        try:
            test_value = test_value[key]
        except (KeyError, TypeError):
            # KeyError: the key is absent. TypeError: the path ran into a
            # value that is not indexable by a string key (a str, a number,
            # None, a non-mapping list element, ...). Either way the path
            # does not exist, so fail closed instead of raising.
            return False
        if isinstance(test_value, list):
            return any(
                cls._find_in_dict(val, path_segments, match)
                for val in test_value
            )
        return cls._find_in_dict(test_value, path_segments, match)

    def __call__(
        self,
        target: TargetT,
        creds: MutableMapping[str, Any],
        enforcer: 'Enforcer',
        current_rule: str | None = None,
    ) -> bool:
        try:
            match = self.match % target
        except KeyError:
            # While doing GenericCheck if key not
            # present in Target return false
            return False
        try:
            # Try to interpret self.kind as a literal
            test_value = ast.literal_eval(self.kind)
            return match == str(test_value)
        except (ValueError, SyntaxError):
            # Not a literal. literal_eval raises ValueError for well-formed
            # non-literal expressions such as ``user.id``. It raises
            # SyntaxError for text that is not valid Python at all, such as
            # ``2fa.enabled``. Both mean "treat kind as a creds path".
            pass

        path_segments = self.kind.split('.')
        return self._find_in_dict(creds, path_segments, match)