File: ast_utils.py

package info (click to toggle)
opencv 4.10.0%2Bdfsg-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 282,092 kB
  • sloc: cpp: 1,178,079; xml: 682,621; python: 49,092; lisp: 31,150; java: 25,469; ansic: 11,039; javascript: 6,085; sh: 1,214; cs: 601; perl: 494; objc: 210; makefile: 173
file content (436 lines) | stat: -rw-r--r-- 17,346 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
from typing import (NamedTuple, Sequence, Tuple, Union, List,
                    Dict, Callable, Optional, Generator, cast)
import keyword

from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
                    EnumerationNode, ClassProperty, OptionalTypeNode,
                    TupleTypeNode)

from .types_conversion import create_type_node


class ScopeNotFoundError(Exception):
    """Raised by `find_scope` when the parent scope of a symbol can't be
    reached, e.g. an enclosing class is not registered yet or a namespace is
    missing and its creation is not allowed.
    """


class SymbolNotFoundError(Exception):
    """Raised when the scope of a symbol is found, but the symbol itself
    (class or function) is not registered in that scope.
    """


class SymbolName(NamedTuple):
    """Symbol name split into its namespaces, enclosing classes and the
    "bare" name of the symbol itself.
    """
    namespaces: Tuple[str, ...]
    classes: Tuple[str, ...]
    name: str

    def __str__(self) -> str:
        joined_namespaces = '::'.join(self.namespaces)
        joined_classes = '::'.join(self.classes)
        return '(namespace="{}", classes="{}", name="{}")'.format(
            joined_namespaces, joined_classes, self.name
        )

    def __repr__(self) -> str:
        # Representation intentionally matches the human readable form
        return str(self)

    @classmethod
    def parse(cls, full_symbol_name: str,
              known_namespaces: Sequence[str],
              symbol_parts_delimiter: str = '.') -> "SymbolName":
        """Splits a full symbol name into namespaces, classes and the "bare"
        symbol name using the collection of already known namespaces.

        The longest leading run of name chunks that forms a known namespace
        is treated as namespaces, every chunk between that run and the last
        chunk is treated as an enclosing class.

        Args:
            full_symbol_name (str): Input string to parse symbol name from.
            known_namespaces (Sequence[str]): Collection of namespaces that
                were met during C++ headers parsing. Entries are always
                '.'-delimited regardless of `symbol_parts_delimiter`.
            symbol_parts_delimiter (str, optional): Delimiter string used to
                split `full_symbol_name` string into chunks. Defaults to '.'.

        Returns:
            SymbolName: Parsed symbol name structure.

        >>> SymbolName.parse('cv.ns.Feature', ('cv', 'cv.ns'))
        (namespace="cv::ns", classes="", name="Feature")

        >>> SymbolName.parse('cv.ns.Feature', ())
        (namespace="", classes="cv::ns", name="Feature")

        >>> SymbolName.parse('cv.ns.Feature.Params', ('cv', 'cv.ns'))
        (namespace="cv::ns", classes="Feature", name="Params")

        >>> SymbolName.parse('cv::ns::Feature::Params::serialize',
        ...                  known_namespaces=('cv', 'cv.ns'),
        ...                  symbol_parts_delimiter='::')
        (namespace="cv::ns", classes="Feature::Params", name="serialize")
        """

        *qualifiers, bare_name = full_symbol_name.split(symbol_parts_delimiter)
        # Shrink the candidate namespace prefix until it is a known namespace
        # (or nothing is left); the cut off tail consists of class names.
        namespaces_count = len(qualifiers)
        while (namespaces_count > 0
               and '.'.join(qualifiers[:namespaces_count]) not in known_namespaces):
            namespaces_count -= 1
        return SymbolName(tuple(qualifiers[:namespaces_count]),
                          tuple(qualifiers[namespaces_count:]),
                          bare_name)


def find_scope(root: NamespaceNode, symbol_name: SymbolName,
               create_missing_namespaces: bool = True) -> Union[NamespaceNode, ClassNode]:
    """Walks down the nodes hierarchy, first through namespaces and then
    through classes, to the direct parent of the node named by `symbol_name`.

    Args:
        root (NamespaceNode): Root node of the hierarchy. Its name must be
            equal to the first namespace of `symbol_name`.
        symbol_name (SymbolName): Full symbol name to find scope for.
        create_missing_namespaces (bool, optional): Set to True to create
            namespaces that are missing in the hierarchy while traversing it.
            Defaults to True.

    Raises:
        ScopeNotFoundError: If direct parent for the node referred by
            `symbol_name` can't be found: one of the classes is not
            registered, or a namespace is missing and
            `create_missing_namespaces` is False.

    Returns:
        Union[NamespaceNode, ClassNode]: Direct parent for the node referred by
            `symbol_name`.

    >>> root = NamespaceNode('cv')
    >>> algorithm_node = root.add_class('Algorithm')
    >>> find_scope(root, SymbolName(('cv', ), ('Algorithm',), 'Params')) == algorithm_node
    True

    >>> root = NamespaceNode('cv')
    >>> scope = find_scope(root, SymbolName(('cv', 'gapi', 'detail'), (), 'function'))
    >>> scope.full_export_name
    'cv.gapi.detail'

    >>> root = NamespaceNode('cv')
    >>> scope = find_scope(root, SymbolName(('cv', 'gapi'), ('GOpaque',), 'function'))
    Traceback (most recent call last):
    ...
    ast_utils.ScopeNotFoundError: Can't find a scope for 'function', with \
'(namespace="cv::gapi", classes="GOpaque", name="function")', \
because 'GOpaque' class is not registered yet
    """
    assert isinstance(root, NamespaceNode), \
        'Wrong hierarchy root type: {}'.format(type(root))

    assert symbol_name.namespaces[0] == root.name, \
        "Trying to find scope for '{}' with root namespace different from: '{}'".format(
            symbol_name, root.name
    )

    current: Union[NamespaceNode, ClassNode] = root

    # Namespaces part: the first namespace is the root itself
    for namespace_name in symbol_name.namespaces[1:]:
        if namespace_name in current.namespaces:  # type: ignore
            current = current.namespaces[namespace_name]  # type: ignore
            continue
        if not create_missing_namespaces:
            raise ScopeNotFoundError(
                "Can't find a scope for '{}', with '{}', because namespace"
                " '{}' is not created yet and `create_missing_namespaces`"
                " flag is set to False".format(
                    symbol_name.name, symbol_name, namespace_name
                )
            )
        current = current.add_namespace(namespace_name)  # type: ignore

    # Classes part: classes are never created implicitly
    for enclosing_class in symbol_name.classes:
        if enclosing_class not in current.classes:
            raise ScopeNotFoundError(
                "Can't find a scope for '{}', with '{}', because '{}' "
                "class is not registered yet".format(
                    symbol_name.name, symbol_name, enclosing_class
                )
            )
        current = current.classes[enclosing_class]
    return current


def find_class_node(root: NamespaceNode, class_symbol: SymbolName,
                    create_missing_namespaces: bool = False) -> ClassNode:
    """Finds the class node referred by `class_symbol`.

    Args:
        root (NamespaceNode): Root node of the hierarchy.
        class_symbol (SymbolName): Full symbol name of the class.
        create_missing_namespaces (bool, optional): Forwarded to `find_scope`.
            Defaults to False.

    Raises:
        SymbolNotFoundError: If the scope of the class exists, but the class
            is not registered in it.

    Returns:
        ClassNode: Class node registered under `class_symbol.name`.
    """
    known_classes = find_scope(root, class_symbol,
                               create_missing_namespaces).classes
    if class_symbol.name in known_classes:
        return known_classes[class_symbol.name]
    raise SymbolNotFoundError(
        "Can't find {} in its scope".format(class_symbol)
    )


def find_function_node(root: NamespaceNode, function_symbol: SymbolName,
                       create_missing_namespaces: bool = False) -> FunctionNode:
    """Finds the function node referred by `function_symbol`.

    Args:
        root (NamespaceNode): Root node of the hierarchy.
        function_symbol (SymbolName): Full symbol name of the function.
        create_missing_namespaces (bool, optional): Forwarded to `find_scope`.
            Defaults to False.

    Raises:
        SymbolNotFoundError: If the scope of the function exists, but the
            function is not registered in it.

    Returns:
        FunctionNode: Function node registered under `function_symbol.name`.
    """
    known_functions = find_scope(root, function_symbol,
                                 create_missing_namespaces).functions
    if function_symbol.name in known_functions:
        return known_functions[function_symbol.name]
    raise SymbolNotFoundError(
        "Can't find {} in its scope".format(function_symbol)
    )


def create_function_node_in_scope(scope: Union[NamespaceNode, ClassNode],
                                  func_info) -> FunctionNode:
    """Creates a function node for `func_info` and sets `scope` as its parent.

    One overload is added per entry of `func_info.variants`. Constructors are
    exported under the `__init__` name. If `scope` is a class node, then
    overloads of non-static functions get `self` as the first argument.
    Static functions whose return type name ends with the class name are
    marked as class methods and get `cls` as the first argument, all other
    static functions are marked as static.

    Args:
        scope (Union[NamespaceNode, ClassNode]): Node that becomes the parent
            of the created function node.
        func_info: Function description. `name`, `isconstructor`, `is_static`
            and `variants` attributes are used.

    Returns:
        FunctionNode: Created function node with all overloads registered.
    """
    def prepare_overload_arguments_and_return_type(variant):
        """Builds arguments list and return type for a single overload.

        Returns a pair of arguments list and return type. Return type is None
        for constructors and for functions without outputs.
        """
        arguments = []  # type: list[FunctionNode.Arg]
        # Enumerate is required, because `argno` in `variant.py_arglist`
        # refers to position of argument in C++ function interface,
        # but `variant.py_noptargs` refers to position in `py_arglist`
        for i, (_, argno) in enumerate(variant.py_arglist):
            arg_info = variant.args[argno]
            type_node = create_type_node(arg_info.tp)
            default_value = None
            # Empty default value string is treated as absence of the default
            if len(arg_info.defval):
                default_value = arg_info.defval
            # If argument is optional and can be None - make its type optional
            if variant.is_arg_optional(i):
                # NOTE: should UMat be always mandatory for better type hints?
                # otherwise overload won't be selected e.g. VideoCapture.read()
                if arg_info.py_outputarg:
                    type_node = OptionalTypeNode(type_node)
                    default_value = "None"
                elif arg_info.isbig() and "None" not in type_node.typename:
                    # but avoid wrapping an already optional type twice
                    type_node = OptionalTypeNode(type_node)
            arguments.append(
                FunctionNode.Arg(arg_info.export_name, type_node=type_node,
                                 default_value=default_value)
            )
        # Constructors have no return type
        if func_info.isconstructor:
            return arguments, None

        # Function has more than 1 output argument, so its return type is a tuple
        if len(variant.py_outlist) > 1:
            ret_types = []
            # Actual returned value of the function goes first
            if variant.py_outlist[0][1] == -1:
                ret_types.append(create_type_node(variant.rettype))
                outlist = variant.py_outlist[1:]
            else:
                outlist = variant.py_outlist
            for _, argno in outlist:
                assert argno >= 0, \
                    f"Logic Error! Outlist contains function return type: {outlist}"

                ret_types.append(create_type_node(variant.args[argno].tp))

            return arguments, FunctionNode.RetType(
                TupleTypeNode("return_type", ret_types)
            )
        # Function with 1 output argument in Python
        if len(variant.py_outlist) == 1:
            # Can be represented as a function with a non-void return type in C++
            if variant.rettype:
                return arguments, FunctionNode.RetType(
                    create_type_node(variant.rettype)
                )
            # or a function with void return type and output argument type
            # such non-const reference
            ret_type = variant.args[variant.py_outlist[0][1]].tp
            return arguments, FunctionNode.RetType(
                create_type_node(ret_type)
            )
        # Function without output types returns None in Python
        return arguments, None

    function_node = FunctionNode(func_info.name)
    function_node.parent = scope
    if func_info.isconstructor:
        function_node.export_name = "__init__"
    for variant in func_info.variants:
        arguments, ret_type = prepare_overload_arguments_and_return_type(variant)
        if isinstance(scope, ClassNode):
            if func_info.is_static:
                # Static function returning the type whose name ends with the
                # name of its class is exposed as a class method
                if ret_type is not None and ret_type.typename.endswith(scope.name):
                    function_node.is_classmethod = True
                    arguments.insert(0, FunctionNode.Arg("cls"))
                else:
                    function_node.is_static = True
            else:
                arguments.insert(0, FunctionNode.Arg("self"))
        function_node.add_overload(arguments, ret_type)
    return function_node


def create_function_node(root: NamespaceNode, func_info) -> FunctionNode:
    """Creates a function node for `func_info` in the scope described by its
    `namespace` and `classname` attributes.

    Args:
        root (NamespaceNode): Root node of the hierarchy.
        func_info: Function description. Its '.'-delimited `namespace` and
            `classname` strings together with `name` form the symbol name
            used to find the scope of the function.

    Returns:
        FunctionNode: Function node created in the found scope.
    """
    # `str.split` returns a list, but `SymbolName` fields are declared as
    # tuples: convert them, so the symbol name stays hashable and consistent
    # with the names produced by `SymbolName.parse`.
    func_symbol_name = SymbolName(
        tuple(func_info.namespace.split(".")) if len(func_info.namespace) else (),
        tuple(func_info.classname.split(".")) if len(func_info.classname) else (),
        func_info.name
    )
    return create_function_node_in_scope(find_scope(root, func_symbol_name),
                                         func_info)


def create_class_node_in_scope(scope: Union[NamespaceNode, ClassNode],
                               symbol_name: SymbolName,
                               class_info) -> ClassNode:
    """Creates a class node named `symbol_name.name` in `scope` and fills it
    with properties, constructor and methods described by `class_info`.

    Args:
        scope (Union[NamespaceNode, ClassNode]): Node to add the class to.
        symbol_name (SymbolName): Symbol name of the class, only its `name`
            part is used.
        class_info: Class description. `props`, `export_name`, `constructor`
            and `methods` attributes are used.

    Returns:
        ClassNode: Created class node.
    """
    def as_exported_name(native_name):
        # Python keywords can't be used as attribute names, so they are
        # exported with a trailing underscore
        if keyword.iskeyword(native_name):
            return native_name + "_"
        return native_name

    class_properties = [
        ClassProperty(
            name=as_exported_name(prop_info.name),
            type_node=create_type_node(prop_info.tp),
            is_readonly=prop_info.readonly
        )
        for prop_info in class_info.props
    ]
    class_node = scope.add_class(symbol_name.name,
                                 properties=class_properties)
    class_node.export_name = class_info.export_name

    constructor_info = class_info.constructor
    if constructor_info is not None:
        create_function_node_in_scope(class_node, constructor_info)
    for method_info in class_info.methods.values():
        create_function_node_in_scope(class_node, method_info)
    return class_node


def create_class_node(root: NamespaceNode, class_info,
                      namespaces: Sequence[str]) -> ClassNode:
    """Creates a class node for `class_info` in the scope deduced from its
    full original name.

    Args:
        root (NamespaceNode): Root node of the hierarchy.
        class_info: Class description, its `full_original_name` is parsed to
            find the scope of the class.
        namespaces (Sequence[str]): Known namespaces used to tell namespaces
            from enclosing classes in the class name.

    Returns:
        ClassNode: Created class node.
    """
    class_symbol = SymbolName.parse(class_info.full_original_name, namespaces)
    return create_class_node_in_scope(find_scope(root, class_symbol),
                                      class_symbol, class_info)


def resolve_enum_scopes(root: NamespaceNode,
                        enums: Dict[SymbolName, EnumerationNode]):
    """Sets the parent of every enumeration node to the class or namespace
    it belongs to.

    Classes enclosing an enumeration that are missing in the AST are created
    and marked as not exported. This is needed when an enumeration is defined
    in a class that is not exported itself, e.g. a base class, but only its
    derivatives are used. Example:
        ```cpp
        class CV_EXPORTS TermCriteria {
        public:
        enum Type { /* ... */ };
        // ...
        };
        ```

    Args:
        root (NamespaceNode): root of the reconstructed AST
        enums (Dict[SymbolName, EnumerationNode]): Mapping between enumerations
            symbol names and corresponding nodes without parents.
    """

    for enum_symbol, enum_node in enums.items():
        if not enum_symbol.classes:
            # Enumeration is defined directly in a namespace
            enum_node.parent = find_scope(root, enum_symbol)
            continue

        try:
            enum_scope = find_scope(root, enum_symbol)
        except ScopeNotFoundError:
            # At least one of the enclosing classes is not registered.
            # Walk enclosing classes from the outermost to the innermost one
            # and register every missing class as a not exported one.
            for depth, class_name in enumerate(enum_symbol.classes):
                class_symbol = SymbolName(enum_symbol.namespaces,
                                          classes=enum_symbol.classes[:depth],
                                          name=class_name)
                class_scope = find_scope(root, class_symbol)
                if class_name not in class_scope.classes:
                    hidden_class = class_scope.add_class(class_name)
                    hidden_class.is_exported = False
            # All enclosing classes exist now
            enum_scope = find_scope(root, enum_symbol)
        enum_node.parent = enum_scope


def get_enclosing_namespace(
    node: ASTNode,
    class_node_callback: Optional[Callable[[ClassNode], None]] = None
) -> NamespaceNode:
    """Walks up the nodes hierarchy starting from the parent of `node` until
    a namespace node is met.

    Args:
        node (ASTNode): Node to find a namespace for.
        class_node_callback (Optional[Callable[[ClassNode], None]]): Optional
            callable object invoked for each traversed non-namespace node
            (treated as a class node) in bottom-up order. Defaults: None.

    Returns:
        NamespaceNode: Closest enclosing namespace of the provided node.

    Raises:
        AssertionError: if nodes hierarchy missing a namespace node.

    >>> root = NamespaceNode('cv')
    >>> feature_class = root.add_class("Feature")
    >>> get_enclosing_namespace(feature_class) == root
    True

    >>> root = NamespaceNode('cv')
    >>> feature_class = root.add_class("Feature")
    >>> feature_params_class = feature_class.add_class("Params")
    >>> serialize_params_func = feature_params_class.add_function("serialize")
    >>> get_enclosing_namespace(serialize_params_func) == root
    True

    >>> root = NamespaceNode('cv')
    >>> detail_ns = root.add_namespace('detail')
    >>> flags_enum = detail_ns.add_enumeration('Flags')
    >>> get_enclosing_namespace(flags_enum) == detail_ns
    True
    """
    ancestor = node.parent
    while True:
        if isinstance(ancestor, NamespaceNode):
            return ancestor
        # Top of the hierarchy is reached without meeting a namespace
        assert ancestor is not None, \
            "Can't find enclosing namespace for '{}' known as: '{}'".format(
                node.full_export_name, node.native_name
            )
        if class_node_callback:
            class_node_callback(cast(ClassNode, ancestor))
        ancestor = ancestor.parent


def get_enum_module_and_export_name(enum_node: EnumerationNode) -> Tuple[str, str]:
    """Builds export name of the enumeration and finds the full name of its
    module.

    Note: Enumeration export name is prefixed with export names of all its
    enclosing classes, outermost class first, joined with "_".

    Args:
        enum_node (EnumerationNode): Enumeration node to construct name for.

    Returns:
        Tuple[str, str]: a pair of enum export name and its full module name.
    """
    # Name parts are collected in bottom-up order: enumeration itself first,
    # then its enclosing classes from the innermost to the outermost one
    name_parts = [enum_node.export_name]

    def collect_class_export_name(class_node: ClassNode) -> None:
        name_parts.append(class_node.export_name)

    namespace_node = get_enclosing_namespace(enum_node,
                                             collect_class_export_name)
    name_parts.reverse()
    return "_".join(name_parts), namespace_node.full_export_name


def for_each_class(
    node: Union[NamespaceNode, ClassNode]
) -> Generator[ClassNode, None, None]:
    """Yields all classes of `node` including nested ones.

    Every class is yielded before the classes nested in it (depth-first,
    pre-order traversal).
    """
    for class_node in node.classes.values():
        yield class_node
        has_nested_classes = len(class_node.classes) > 0
        if has_nested_classes:
            for nested_class_node in for_each_class(class_node):
                yield nested_class_node


def for_each_function(
    node: Union[NamespaceNode, ClassNode],
    traverse_class_nodes: bool = True
) -> Generator[FunctionNode, None, None]:
    """Yields functions of `node` and, optionally, functions of all its
    classes including nested ones.

    Every function node is yielded exactly once: functions of `node` go
    first, followed by functions of each class in the order produced by
    `for_each_class`.

    Args:
        node (Union[NamespaceNode, ClassNode]): Node to take functions from.
        traverse_class_nodes (bool, optional): Set to False to yield only
            functions that belong directly to `node`. Defaults to True.

    Yields:
        FunctionNode: Function nodes found in `node` and its classes.
    """
    yield from node.functions.values()
    if traverse_class_nodes:
        # `for_each_class` already visits nested classes recursively, so only
        # functions owned directly by each class are taken here. Recursing
        # into classes again would yield functions of nested classes
        # several times.
        for cls in for_each_class(node):
            yield from cls.functions.values()


def for_each_function_overload(
    node: Union[NamespaceNode, ClassNode],
    traverse_class_nodes: bool = True
) -> Generator[FunctionNode.Overload, None, None]:
    """Yields overloads of every function produced by `for_each_function`
    called with the same arguments.
    """
    for function_node in for_each_function(node, traverse_class_nodes):
        for overload in function_node.overloads:
            yield overload


if __name__ == '__main__':
    # Run the examples embedded in the docstrings of this module as tests
    import doctest
    doctest.testmod()