File: ordering.py

package info (click to toggle)
python-apischema 0.18.3-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,636 kB
  • sloc: python: 15,281; makefile: 3; sh: 2
file content (142 lines) | stat: -rw-r--r-- 3,845 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from collections import defaultdict
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Sequence,
    TypeVar,
    overload,
)

from apischema.cache import CacheAwareDict
from apischema.metadata.keys import ORDERING_METADATA
from apischema.types import MetadataMixin
from apischema.utils import stop_signature_abuse

# Type variable for the class that the decorator form of `order` receives and
# hands back (see the `Callable[[Cls], Cls]` overloads of `order`).
Cls = TypeVar("Cls", bound=type)


@dataclass(frozen=True)
class Ordering(MetadataMixin):
    """Immutable description of where an element should be placed.

    An instance carries either an explicit integer rank (``order``) or a
    position relative to another element (``after`` / ``before``).
    ``sort_by_order`` reads these attributes in that priority: ``order``
    first, then ``after``, then ``before``.
    """

    key = ORDERING_METADATA
    # Explicit rank; `sort_by_order` emits ranks in ascending order.
    order: Optional[int] = None
    # Reference to the element this one must follow.
    after: Optional[Any] = None
    # Reference to the element this one must precede.
    before: Optional[Any] = None

    def __post_init__(self):
        # Imported lazily, inside the method, exactly as the module does for
        # every other `apischema.objects.fields` helper.
        from apischema.objects.fields import check_field_or_name

        # Validate whichever relative references were supplied, `after` first.
        for reference in (self.after, self.before):
            if reference is not None:
                check_field_or_name(reference, methods=True)


# Registry of per-class ordering overrides: written by the class decorator that
# `order` returns and read back by `get_order_overriding`.
# NOTE(review): CacheAwareDict comes from apischema.cache; presumably it keeps
# apischema's caches consistent when the registry changes — confirm there.
_order_overriding: MutableMapping[type, Mapping[Any, Ordering]] = CacheAwareDict({})


# Explicit rank: a single positional integer.
@overload
def order(__value: int) -> Ordering:
    ...


# Relative position: the element comes after the given one.
@overload
def order(*, after: Any) -> Ordering:
    ...


# Relative position: the element comes before the given one.
@overload
def order(*, before: Any) -> Ordering:
    ...


# Class decorator built from a sequence of fields, each ordered after the
# preceding one.
@overload
def order(__fields: Sequence[Any]) -> Callable[[Cls], Cls]:
    ...


# Class decorator built from an explicit field -> Ordering mapping.
@overload
def order(__override: Mapping[Any, Ordering]) -> Callable[[Cls], Cls]:
    ...


def order(__arg=None, *, before=None, after=None):
    """Build an ``Ordering`` or a class decorator registering overrides.

    Exactly one of the positional argument, ``before`` and ``after`` is
    expected; any other combination triggers ``stop_signature_abuse``.

    * ``None`` or an ``int`` positional argument produces
      ``Ordering(order, after, before)``.
    * A ``Sequence`` is converted into a mapping in which every item is
      ordered after the item preceding it, then handled as a mapping.
    * A ``Mapping`` (all values must be ``Ordering`` instances) produces a
      decorator that stores the mapping in ``_order_overriding`` for the
      decorated class and returns the class unchanged.
    * Any other positional value triggers ``stop_signature_abuse``.

    NOTE(review): ``stop_signature_abuse`` is defined in ``apischema.utils``;
    presumably it raises — confirm there.
    """
    given = sum(value is not None for value in (__arg, before, after))
    if given != 1:
        stop_signature_abuse()
    if isinstance(__arg, Sequence):
        # Chain the fields: every item is ordered after the one preceding it.
        chained = {}
        for field, previous in zip(__arg[1:], __arg):
            chained[field] = order(after=previous)
        __arg = chained
    if not isinstance(__arg, Mapping):
        if __arg is None or isinstance(__arg, int):
            return Ordering(order=__arg, after=after, before=before)
        stop_signature_abuse()
        return None
    if any(not isinstance(value, Ordering) for value in __arg.values()):
        stop_signature_abuse()
    overriding = __arg

    def decorator(cls: Cls) -> Cls:
        # Register the overrides for this class and leave the class untouched.
        _order_overriding[cls] = overriding
        return cls

    return decorator


def get_order_overriding(cls: type) -> Mapping[str, Ordering]:
    """Collect the ordering overrides registered for *cls* and its bases.

    The MRO is walked in reverse, so an override registered on a class
    appearing earlier in the MRO replaces one registered, for the same field
    name, on a class appearing later. Keys are the names returned by
    ``get_field_name(field, methods=True)``.
    """
    from apischema.objects.fields import get_field_name

    merged: Dict[str, Ordering] = {}
    for klass in reversed(cls.__mro__):
        if klass not in _order_overriding:
            continue
        for field, ordering in _order_overriding[klass].items():
            merged[get_field_name(field, methods=True)] = ordering
    return merged


# Generic type of the elements sorted by `sort_by_order`.
T = TypeVar("T")


def sort_by_order(
    cls: type,
    elts: Collection[T],
    name: Callable[[T], str],
    order: Callable[[T], Optional[Ordering]],
) -> Sequence[T]:
    """Sort *elts* according to their orderings.

    The ordering of an element is the override registered for its name on
    *cls* (see ``get_order_overriding``) if there is one, else ``order(elt)``.

    * No ordering: the element goes to rank ``0``.
    * ``Ordering.order`` set: the element goes to that rank.
    * ``Ordering.after`` / ``Ordering.before`` set: the element is emitted
      right after / right before the element bearing the referenced name.

    Ranks are emitted in ascending order and, inside a rank, elements keep
    their input order.

    Every input element is present in the result. A relative ordering whose
    anchor name matches no element of *elts* is ignored (the element is
    treated as unordered), and elements that can only be reached through a
    cycle of relative orderings are appended at the end in input order.
    Previously both kinds of elements were silently dropped.

    Args:
        cls: class whose registered ordering overrides apply.
        elts: elements to sort.
        name: returns the name of an element.
        order: returns the ordering declared by an element, if any.

    Returns:
        The sorted elements.

    Raises:
        NotImplementedError: if an ``Ordering`` has none of ``order``,
            ``after`` and ``before`` set.
    """
    from apischema.objects.fields import get_field_name

    order_overriding = get_order_overriding(cls)
    # Materialize once: names are needed up front to detect dangling anchors,
    # and the elements are walked again at the end to recover unplaced ones.
    named_elts = [(name(elt), elt) for elt in elts]
    known_names = {elt_name for elt_name, _ in named_elts}
    groups: Dict[int, List[T]] = defaultdict(list)
    after: Dict[str, List[T]] = defaultdict(list)
    before: Dict[str, List[T]] = defaultdict(list)

    def attach(relative: Dict[str, List[T]], anchor: Any, elt: T) -> None:
        anchor_name = get_field_name(anchor, methods=True)
        if anchor_name in known_names:
            relative[anchor_name].append(elt)
        else:
            # The anchor is not part of `elts`: nothing to be relative to, so
            # handle the element as if it had no ordering instead of losing it.
            groups[0].append(elt)

    for elt_name, elt in named_elts:
        ordering = order_overriding.get(elt_name, order(elt))
        if ordering is None:
            groups[0].append(elt)
        elif ordering.order is not None:
            groups[ordering.order].append(elt)
        elif ordering.after is not None:
            attach(after, ordering.after, elt)
        elif ordering.before is not None:
            attach(before, ordering.before, elt)
        else:
            raise NotImplementedError
    # Fast path: a single rank and no relative ordering means input order.
    if not after and not before and len(groups) == 1:
        return next(iter(groups.values()))
    result: List[T] = []

    def add_to_result(elt: T):
        # Emit the elements declared before `elt`, then `elt`, then the ones
        # declared after it, recursively.
        elt_name = name(elt)
        for before_elt in before.get(elt_name, ()):
            add_to_result(before_elt)
        result.append(elt)
        for after_elt in after.get(elt_name, ()):
            add_to_result(after_elt)

    for value in sorted(groups):
        for elt in groups[value]:
            add_to_result(elt)
    # Elements only reachable through a cycle of relative orderings are never
    # visited above; keep them, in input order, rather than dropping them.
    # Identity is used because elements are not required to be hashable.
    placed = {id(elt) for elt in result}
    result.extend(elt for _, elt in named_elts if id(elt) not in placed)
    return result