File: treeadapter.py

package info (click to toggle)
orange3 3.40.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 15,908 kB
  • sloc: python: 162,745; ansic: 622; makefile: 322; sh: 93; cpp: 77
file content (351 lines) | stat: -rw-r--r-- 7,551 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
"""Base tree adapter class with common methods needed for visualisations."""
from abc import ABCMeta, abstractmethod
from functools import reduce
from operator import add
import random


class BaseTreeAdapter(metaclass=ABCMeta):
    """Abstract interface through which visualisations query a tree model.

    Concrete adapters must implement every abstract member declared here.
    A few simple helpers already come with a default implementation and do
    not need to be overridden, e.g. ``is_leaf``, which is simply the
    negation of ``has_children``.

    """

    # Value that `parent` is documented to return for the root node.
    ROOT_PARENT = None
    # Sentinel constants; they are not referenced inside this class.
    NO_CHILD = -1
    FEATURE_UNDEFINED = -2

    def __init__(self, model):
        """Store the model, its domain and (if available) its instances.

        Parameters
        ----------
        model : object
            The wrapped tree model. It has to expose ``domain`` and
            ``instances`` attributes; ``instances`` may be None.

        """
        self.model = model
        self.domain = model.domain
        data = model.instances
        self.instances = data
        if data is None:
            # No data stored on the model, hence nothing to transform.
            self.instances_transformed = None
        else:
            self.instances_transformed = data.transform(self.domain)

    @abstractmethod
    def weight(self, node):
        """Return the weight of a node relative to its siblings.

        The weights of all children of one parent add up to 1.

        Parameters
        ----------
        node : object
            The label of the node.

        Returns
        -------
        float
            The share of the node among its siblings.

        """

    @abstractmethod
    def num_samples(self, node):
        """Return how many samples the given node contains.

        Parameters
        ----------
        node : object
            A unique identifier of a node.

        Returns
        -------
        int

        """

    @abstractmethod
    def parent(self, node):
        """Return the parent of a node, or ROOT_PARENT for the root node.

        Parameters
        ----------
        node : object

        Returns
        -------
        object

        """

    @abstractmethod
    def has_children(self, node):
        """Tell whether the given node has at least one child.

        Parameters
        ----------
        node : object

        Returns
        -------
        bool

        """

    def is_leaf(self, node):
        """Tell whether the given node is a leaf, i.e. has no children.

        Parameters
        ----------
        node : object

        Returns
        -------
        bool

        """
        has_offspring = self.has_children(node)
        return not has_offspring

    @abstractmethod
    def children(self, node):
        """Return the children of the given node.

        Parameters
        ----------
        node : object

        Returns
        -------
        Iterable[object]
            An iterable with the labels of the child nodes.

        """

    def reverse_children(self, node):
        """Reverse the order of the children of the given node.

        The base implementation does nothing.

        Parameters
        ----------
        node : object
        """

    def shuffle_children(self):
        """Randomly reorder the children of every node in the tree.

        The base implementation does nothing.
        """

    @abstractmethod
    def get_distribution(self, node):
        """Return the distribution of types within the given node.

        This may be, for instance, a count for each of the different
        classes present in the node.

        Parameters
        ----------
        node : object

        Returns
        -------
        Iterable[int, ...]
            An iterable with one field per class found in the node; each
            field holds the count that belongs to that class.

        """

    @abstractmethod
    def get_impurity(self, node):
        """Return the impurity of the given node.

        Parameters
        ----------
        node : object

        Returns
        -------
        object

        """

    @abstractmethod
    def rules(self, node):
        """Return the rules that define the given node.

        Parameters
        ----------
        node : object

        Returns
        -------
        Iterable[Rule]
            A collection of Rule objects; these can be of any type.

        """

    @abstractmethod
    def short_rule(self, node):
        """Return a short form of the rule for the given node.

        NOTE(review): the format of the result is not specified by this
        base class; check the concrete adapters.

        Parameters
        ----------
        node : object

        """

    @abstractmethod
    def attribute(self, node):
        """Return the attribute on which the given node is split.

        Parameters
        ----------
        node : object

        """

    def is_root(self, node):
        """Tell whether the given node is the root of the tree.

        The node is compared for equality against the ``root`` property.

        Parameters
        ----------
        node : object

        Returns
        -------
        bool
            For ordinary labels; strictly, whatever ``node == self.root``
            evaluates to.

        """
        tree_root = self.root
        return node == tree_root

    @abstractmethod
    def leaves(self, node):
        """Return all leaves in the subtree rooted at the given node.

        Parameters
        ----------
        node : object

        """

    @abstractmethod
    def get_instances_in_nodes(self, dataset, nodes):
        """Return the instances of a dataset that belong to a set of nodes.

        Parameters
        ----------
        dataset : Table
            An Orange Table dataset.
        nodes : iterable[node]
            The tree nodes whose instances are requested.

        """

    @abstractmethod
    def get_indices(self, nodes):
        """Return the indices associated with the given nodes.

        NOTE(review): this base class does not state what is being
        indexed; check the concrete adapters.

        Parameters
        ----------
        nodes : iterable[node]

        """

    @property
    @abstractmethod
    def max_depth(self):
        """The maximum depth reached by the tree.

        Returns
        -------
        int

        """

    @property
    @abstractmethod
    def num_nodes(self):
        """The total number of nodes in the tree.

        This counts nodes only; it is not the number of samples contained
        in the tree.

        Returns
        -------
        int

        """

    @property
    @abstractmethod
    def root(self):
        """The label of the root node.

        Returns
        -------
        object

        """


class TreeAdapter(BaseTreeAdapter):
    """Adapter for models whose nodes are linked objects.

    The methods below read the node attributes ``subset``, ``parent``,
    ``children``, ``value``, ``description`` and ``attr``, and forward the
    remaining queries to the wrapped model (``root``, ``rule``,
    ``get_instances``, ``get_indices``, ``depth`` and ``node_count``).
    """

    def weight(self, node):
        """Return the node's subset size divided by its parent's subset size."""
        own_size = len(node.subset)
        parent_size = len(node.parent.subset)
        return own_size / parent_size

    def num_samples(self, node):
        """Return the length of the node's ``subset``."""
        subset = node.subset
        return len(subset)

    def parent(self, node):
        """Return the node's ``parent`` attribute."""
        return node.parent

    def has_children(self, node):
        """Tell whether any entry of ``node.children`` is truthy."""
        for child in node.children:
            if child:
                return True
        return False

    def is_leaf(self, node):
        """Tell whether no entry of ``node.children`` is truthy."""
        return all(not child for child in node.children)

    def children(self, node):
        """Return a new list with the node's children, skipping None entries."""
        present = []
        for child in node.children:
            if child is not None:
                present.append(child)
        return present

    def reverse_children(self, node):
        """Replace ``node.children`` with a reversed copy of itself."""
        flipped = node.children[::-1]
        node.children = flipped

    def shuffle_children(self):
        """Shuffle, in place, the children of every node below the root.

        Nodes are visited in pre-order: a node's children are shuffled
        first and then visited in their new order.
        """
        pending = [self.root]
        while pending:
            current = pending.pop()
            if not current or not current.children:
                continue
            random.shuffle(current.children)
            # Push in reverse so that the first child is handled next.
            pending.extend(reversed(current.children))

    def get_distribution(self, node):
        """Return a one-element list holding ``node.value``."""
        return [node.value]

    def get_impurity(self, node):
        """Not available for this adapter; always raises.

        Raises
        ------
        NotImplementedError

        """
        raise NotImplementedError

    def rules(self, node):
        """Return whatever the model's ``rule`` method gives for the node."""
        model = self.model
        return model.rule(node)

    def short_rule(self, node):
        """Return the node's ``description`` attribute."""
        return node.description

    def attribute(self, node):
        """Return the node's ``attr`` attribute."""
        return node.attr

    def leaves(self, node):
        """Return a list of the leaves below ``node``, or ``[node]`` if none.

        Children are taken from `children`, so None entries are ignored and
        the leaves come out in child order.
        """
        def collect(current):
            below = [collect(child) for child in self.children(current)]
            # An empty concatenation means `current` itself is a leaf.
            return reduce(add, below, []) or [current]
        return collect(node)

    def get_instances_in_nodes(self, nodes):
        """Return the model's instances for the given node or nodes.

        Parameters
        ----------
        nodes : tree.Node or iterable of tree.Node
            A single node is wrapped into a list before being passed on to
            the model's ``get_instances``.

        """
        from Orange import tree
        node_list = [nodes] if isinstance(nodes, tree.Node) else nodes
        return self.model.get_instances(node_list)

    def get_indices(self, nodes):
        """Return whatever the model's ``get_indices`` gives for the nodes."""
        model = self.model
        return model.get_indices(nodes)

    @property
    def max_depth(self):
        """The result of the model's ``depth()``."""
        model = self.model
        return model.depth()

    @property
    def num_nodes(self):
        """The result of the model's ``node_count()``."""
        model = self.model
        return model.node_count()

    @property
    def root(self):
        """The model's ``root`` attribute."""
        model = self.model
        return model.root