File: tree.py

package info (click to toggle)
tryton-server 7.0.43-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 7,760 kB
  • sloc: python: 53,744; xml: 5,194; sh: 803; sql: 217; makefile: 28
file content (177 lines) | stat: -rw-r--r-- 7,318 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# This file is part of Tryton.  The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
from itertools import chain

from trytond.i18n import gettext
from trytond.tools import escape_wildcard

from .modelstorage import ValidationError


class RecursionError(ValidationError):
    "Raised when a record of a tree is found among its own ancestors."
    # NOTE: this name shadows the builtin RecursionError inside this module.


def tree(parent='parent', name='name', separator=None):
    """Return a mixin class adding tree behavior on the *parent* field.

    :param parent: name of the field pointing to the parent record(s); it
        must be a many2one, one2one or many2many field (checked when the
        recursion is validated, ValueError otherwise)
    :param name: name of the field holding the record's own name
    :param separator: when set, ``rec_name`` is built by joining the names
        of the ancestors and of the record with it, *name* may not contain
        it, and searching on ``rec_name`` splits the searched value on it
    """
    from . import fields

    class TreeMixin(object):
        __slots__ = ()

        if separator:
            @classmethod
            def __setup__(cls):
                super(TreeMixin, cls).__setup__()
                field = getattr(cls, name)
                # Forbid the separator inside the name so that a composed
                # rec_name can be split back into the ancestors' names
                clause = ['OR',
                    (name, 'not like', '%' + escape_wildcard(separator) + '%'),
                    (name, '=', None),
                    ]
                # If TreeMixin comes after the class defining name in the
                # __mro__, it modifies the copied base field, so the clause
                # must be added only once
                if clause not in field.domain:
                    domain = [clause]
                    if field.domain:
                        domain.append(field.domain)
                    field.domain = domain

            def get_rec_name(self, _):
                "Return the names from the root down to self joined by separator"
                record, names = self, []
                while record:
                    # name may be None (allowed by the domain set up above)
                    names.append(getattr(record, name) or '')
                    record = getattr(record, parent)
                return separator.join(reversed(names))

            @fields.depends(parent, '_parent_%s.rec_name' % parent, name)
            def on_change_with_rec_name(self):
                "Compute rec_name from the parent's rec_name and own name"
                names = []
                # Use the configured parent field, not a hard-coded "parent"
                parent_record = getattr(self, parent)
                if parent_record and parent_record.rec_name:
                    names.append(parent_record.rec_name)
                names.append(getattr(self, name) or '')
                return separator.join(names)

            @classmethod
            def search_rec_name(cls, _, clause):
                """Return a domain searching clause on the composed rec_name.

                A string value is split on separator: its last part is
                matched against name, the previous part against the parent's
                name and so on up the tree. None is matched against name
                directly; any other value is treated as a sequence whose
                items are searched one by one.
                """
                domain = []
                if isinstance(clause[2], str):
                    field = name
                    # Own name first, then each ancestor's name
                    values = list(reversed(clause[2].split(separator)))
                    for value in values:
                        domain.append((field, clause[1], value.strip()))
                        field = parent + '.' + field
                    # Unless the pattern starts with a wildcard ('%%' is not
                    # counted as one), the first part must match a root
                    if ((
                                clause[1].endswith('like')
                                and not clause[2].replace(
                                    '%%', '__').startswith('%'))
                            or not clause[1].endswith('like')):
                        if clause[1].startswith('not') or clause[1] == '!=':
                            operator = '!='
                            domain.insert(0, 'OR')
                        else:
                            operator = '='
                        top_parent = '.'.join((parent,) * len(values))
                        domain.append((top_parent, operator, None))
                    # A trailing wildcard also matches the descendants of the
                    # records found
                    if (clause[1].endswith('like')
                            and clause[2].replace('%%', '__').endswith('%')):
                        ids = list(map(int, cls.search(domain, order=[])))
                        domain = [(parent, 'child_of', ids)]
                elif clause[2] is None:
                    domain.append((name, clause[1], clause[2]))
                else:
                    # Sequence of values: one equality search per value,
                    # combined with AND for negative operators, OR otherwise
                    if clause[1].startswith('not'):
                        operator = '!='
                        domain.append('AND')
                    else:
                        operator = '='
                        domain.append('OR')
                    for value in clause[2]:
                        domain.append(cls.search_rec_name(
                                name, (clause[0], operator, value)))
                return domain

        @classmethod
        def validate_fields(cls, records, field_names):
            "Validate the fields then check the tree has no recursion"
            super().validate_fields(records, field_names)
            cls.check_recursion(records, field_names)

        @classmethod
        def check_recursion(cls, records, field_names=None):
            '''
            Function that checks if there is no recursion in the tree
            composed with parent as parent field name.

            Raises RecursionError when a record is found among its own
            ancestors and ValueError when the type of the parent field is
            not supported. Does nothing when field_names is given and does
            not contain parent.
            '''
            if hasattr(super(TreeMixin, cls), 'check_recursion'):
                super(TreeMixin, cls).check_recursion(records, field_names)

            if field_names and parent not in field_names:
                return

            parent_type = cls._fields[parent]._type

            if parent_type not in ('many2one', 'many2many', 'one2one'):
                raise ValueError(
                    'Unsupported field type "%s" for field "%s" on "%s"'
                    % (parent_type, parent, cls.__name__))

            # Ids whose whole ancestry was walked up to the roots without
            # meeting any id twice: a walk reaching one of them can stop, as
            # it cannot lead back to the record being checked.
            visited = set()

            for record in records:
                walked = set()
                # Set when the walk is cut because it reached an id already
                # walked for this record (a cycle not including the record,
                # or a shared ancestor for many2many). Without the cut such a
                # cycle would loop forever; the walked ids are then not known
                # to be safe, so they are not added to visited.
                truncated = False
                walker = getattr(record, parent)
                while walker:
                    if parent_type == 'many2many':
                        for walk in walker:
                            walked.add(walk.id)
                            if walk.id == record.id:
                                parent_name = ', '.join(
                                    getattr(r, name) or ''
                                    for r in getattr(record, parent))
                                raise RecursionError(
                                    gettext('ir.msg_recursion_error',
                                        rec_name=getattr(record, name),
                                        parent_rec_name=parent_name))
                        # Next level: parents of the walks not yet known to
                        # be safe, each id only once
                        next_ids, next_walker = set(), []
                        for ancestor in chain.from_iterable(
                                getattr(walk, parent)
                                for walk in walker
                                if walk.id not in visited):
                            if ancestor.id in walked:
                                truncated = True
                            elif ancestor.id not in next_ids:
                                next_ids.add(ancestor.id)
                                next_walker.append(ancestor)
                        walker = next_walker
                    else:
                        walked.add(walker.id)
                        if walker.id == record.id:
                            parent_name = getattr(
                                getattr(record, parent), name)
                            raise RecursionError(
                                gettext('ir.msg_recursion_error',
                                    rec_name=getattr(record, name),
                                    parent_rec_name=parent_name))
                        ancestor = getattr(walker, parent)
                        # Compare ids: visited and walked hold ids, not
                        # records
                        if not ancestor or ancestor.id in visited:
                            walker = None
                        elif ancestor.id in walked:
                            truncated = True
                            walker = None
                        else:
                            walker = ancestor
                if not truncated:
                    visited.update(walked)

    return TreeMixin


def sum_tree(records, values, parent='parent'):
    """Sum up values following records tree.

    The value of each record is added to its parent's value, bottom-up, so
    each record ends up with its own value plus the values of its
    descendants found in *records*. A record whose parent is not in
    *records* still adds its total to that parent's entry. Records that are
    part of a cycle are not propagated.

    :param records: sequence of records, each convertible with ``int()``
        and exposing the *parent* attribute (a record or a falsy value)
    :param values: mapping of record id to the value to sum up; it is not
        modified, even for mutable values
    :param parent: name of the attribute holding the parent record
    :return: a new dict mapping ids to the summed values
    :raises KeyError: when a record to propagate has no value in *values*
    """
    result = values.copy()
    parents = {}
    for record in records:
        parent_record = getattr(record, parent)
        if parent_record:
            parents[int(record)] = int(parent_record)

    # Number of children of each record not summed into it yet
    pending = dict.fromkeys(map(int, records), 0)
    for parent_id in parents.values():
        if parent_id in pending:
            pending[parent_id] += 1

    # Records with no pending child; processed level by level
    leafs = {r for r, count in pending.items() if not count}
    while leafs:
        next_leafs = set()
        for leaf in leafs:
            parent_id = parents.get(leaf)
            if not parent_id:
                continue
            # Rebind rather than "+=" so values shared with the caller's
            # mapping are never mutated in place
            if parent_id in result:
                result[parent_id] = result[parent_id] + result[leaf]
            else:
                result[parent_id] = result[leaf]
            if parent_id in pending:
                pending[parent_id] -= 1
                if not pending[parent_id]:
                    next_leafs.add(parent_id)
        leafs = next_leafs
    return result