File: dataclasses_extra.py

package info (click to toggle)
linux 6.16.5-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,724,468 kB
  • sloc: ansic: 26,560,391; asm: 271,356; sh: 143,999; python: 72,469; makefile: 57,129; perl: 36,821; xml: 19,553; cpp: 5,820; yacc: 4,915; lex: 2,955; awk: 1,667; sed: 28; ruby: 25
file content (109 lines) | stat: -rw-r--r-- 2,590 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from __future__ import annotations

from dataclasses import (
    fields,
    is_dataclass,
    replace,
)
from typing import (
    Protocol,
    TYPE_CHECKING,
)

if TYPE_CHECKING:
    # Typing-only definitions: this branch never runs at runtime. The names
    # below are therefore only usable in annotations (postponed by the
    # `from __future__ import annotations` at the top of the file) and in the
    # lazily evaluated PEP 695 type-parameter bounds of the functions below.
    # `_typeshed` exists only in type stubs, not as an importable module.
    from _typeshed import DataclassInstance as _DataclassInstance

    # Structural type: a dataclass instance that also has a `name: str`
    # attribute. Used as the item bound for `_merge_assoclist`, which pairs
    # list items by their `name`.
    class _HasName(Protocol, _DataclassInstance):
        name: str


def default[T: _DataclassInstance](
    cls: type[T],
    /,
) -> T:
    f = {}

    for field in fields(cls):
        if 'default' in field.metadata:
            f[field.name] = field.metadata['default']

    return cls(**f)


def merge[T: _DataclassInstance](
    self: T,
    other: T | None, /,
) -> T:
    if other is None:
        return self

    f = {}

    for field in fields(self):
        if not field.init:
            continue

        field_default_type = object
        if isinstance(field.default_factory, type):
            field_default_type = field.default_factory

        self_field = getattr(self, field.name)
        other_field = getattr(other, field.name)

        if field.name == 'name':
            assert self_field == other_field
        elif field.type == 'bool':
            f[field.name] = other_field
        elif field.metadata.get('merge') == 'assoclist':
            f[field.name] = _merge_assoclist(self_field, other_field)
        elif is_dataclass(field_default_type):
            f[field.name] = merge(self_field, other_field)
        elif issubclass(field_default_type, list):
            f[field.name] = self_field + other_field
        elif issubclass(field_default_type, dict):
            f[field.name] = self_field | other_field
        elif field.default is None:
            if other_field is not None:
                f[field.name] = other_field
        else:
            raise RuntimeError(f'Unable to merge for type {field.type}')

    return replace(self, **f)


def merge_default[T: _DataclassInstance](
    cls: type[T],
    /,
    *others: T,
) -> T:
    ret: T = default(cls)
    for o in others:
        ret = merge(ret, o)
    return ret


def _merge_assoclist[T: _HasName](
    self_list: list[T],
    other_list: list[T],
    /,
) -> list[T]:
    '''
    Merge lists where each item got a "name" attribute
    '''
    if not self_list:
        return other_list
    if not other_list:
        return self_list

    ret: list[T] = []
    other_dict = {
        i.name: i
        for i in other_list
    }
    for i in self_list:
        if i.name in other_dict:
            ret.append(merge(i, other_dict.pop(i.name)))
        else:
            ret.append(i)
    ret.extend(other_dict.values())
    return ret