File: fix_annotate.py

package info (click to toggle)
pyannotate 1.2.0-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, bullseye
  • size: 384 kB
  • sloc: python: 5,001; makefile: 5
file content (481 lines) | stat: -rw-r--r-- 17,648 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
"""Fixer that inserts mypy annotations into all methods.

This transforms e.g.

  def foo(self, bar, baz=12):
      return bar + baz

into a type-annotated version:

  def foo(self, bar, baz=12):
      # type: (Any, int) -> Any            # noqa: F821
      return bar + baz

or (when setting options['annotation_style'] to 'py3'):

  def foo(self, bar: Any, baz: int = 12) -> Any:
      return bar + baz


It does not do type inference but it recognizes some basic default
argument values such as numbers and strings (and assumes their type
implies the argument type).

It also uses some basic heuristics to decide whether to ignore the
first argument:

  - always if it's named 'self'
  - if there's a @classmethod decorator

Finally, it knows that __init__() is supposed to return None.
"""

from __future__ import print_function

import os
import re

from lib2to3.fixer_base import BaseFix
from lib2to3.fixer_util import syms, touch_import, find_indentation
from lib2to3.patcomp import compile_pattern
from lib2to3.pgen2 import token
from lib2to3.pytree import Leaf, Node


class FixAnnotate(BaseFix):

    # This fixer is compatible with the bottom matcher.
    BM_compatible = True

    # This fixer shouldn't run by default.
    explicit = True

    # The pattern to match.
    PATTERN = """
              funcdef< 'def' name=any parameters=parameters< '(' [args=any] rpar=')' > ':' suite=any+ >
              """

    _maxfixes = os.getenv('MAXFIXES')
    counter = None if not _maxfixes else int(_maxfixes)

    def transform(self, node, results):
        if FixAnnotate.counter is not None:
            if FixAnnotate.counter <= 0:
                return

        # Check if there's already a long-form annotation for some argument.
        parameters = results.get('parameters')
        if parameters is not None:
            for ch in parameters.pre_order():
                if ch.prefix.lstrip().startswith('# type:'):
                    return
        args = results.get('args')
        if args is not None:
            for ch in args.pre_order():
                if ch.prefix.lstrip().startswith('# type:'):
                    return

        children = results['suite'][0].children

        # NOTE: I've reverse-engineered the structure of the parse tree.
        # It's always a list of nodes, the first of which contains the
        # entire suite.  Its children seem to be:
        #
        #   [0] NEWLINE
        #   [1] INDENT
        #   [2...n-2] statements (the first may be a docstring)
        #   [n-1] DEDENT
        #
        # Comments before the suite are part of the INDENT's prefix.
        #
        # "Compact" functions (e.g. "def foo(x, y): return max(x, y)")
        # have a different structure (no NEWLINE, INDENT, or DEDENT).

        # Check if there's already an annotation.
        for ch in children:
            if ch.prefix.lstrip().startswith('# type:'):
                return  # There's already a # type: comment here; don't change anything.

        # Python 3 style return annotation are already skipped by the pattern

        ### Python 3 style argument annotation structure
        #
        # Structure of the arguments tokens for one positional argument without default value :
        # + LPAR '('
        # + NAME_NODE_OR_LEAF arg1
        # + RPAR ')'
        #
        # NAME_NODE_OR_LEAF is either:
        # 1. Just a leaf with value NAME
        # 2. A node with children: NAME, ':", node expr or value leaf
        #
        # Structure of the arguments tokens for one args with default value or multiple
        # args, with or without default value, and/or with extra arguments :
        # + LPAR '('
        # + node
        #   [
        #     + NAME_NODE_OR_LEAF
        #      [
        #        + EQUAL '='
        #        + node expr or value leaf
        #      ]
        #    (
        #        + COMMA ','
        #        + NAME_NODE_OR_LEAF positional argn
        #      [
        #        + EQUAL '='
        #        + node expr or value leaf
        #      ]
        #    )*
        #   ]
        #   [
        #     + STAR '*'
        #     [
        #     + NAME_NODE_OR_LEAF positional star argument name
        #     ]
        #   ]
        #   [
        #     + COMMA ','
        #     + DOUBLESTAR '**'
        #     + NAME_NODE_OR_LEAF positional keyword argument name
        #   ]
        # + RPAR ')'

        # Let's skip Python 3 argument annotations
        it = iter(args.children) if args else iter([])
        for ch in it:
            if ch.type == token.STAR:
                # *arg part
                ch = next(it)
                if ch.type == token.COMMA:
                    continue
            elif ch.type == token.DOUBLESTAR:
                # *arg part
                ch = next(it)
            if ch.type > 256:
                # this is a node, therefore an annotation
                assert ch.children[0].type == token.NAME
                return
            try:
                ch = next(it)
                if ch.type == token.COLON:
                    # this is an annotation
                    return
                elif ch.type == token.EQUAL:
                    ch = next(it)
                    ch = next(it)
                assert ch.type == token.COMMA
                continue
            except StopIteration:
                break

        # Compute the annotation
        annot = self.make_annotation(node, results)
        if annot is None:
            return
        argtypes, restype = annot

        if self.options['annotation_style'] == 'py3':
            self.add_py3_annot(argtypes, restype, node, results)
        else:
            self.add_py2_annot(argtypes, restype, node, results)

        # Common to py2 and py3 style annotations:
        if FixAnnotate.counter is not None:
            FixAnnotate.counter -= 1

        # Also add 'from typing import Any' at the top if needed.
        self.patch_imports(argtypes + [restype], node)

    def add_py3_annot(self, argtypes, restype, node, results):
        args = results.get('args')

        argleaves = []
        if args is None:
            # function with 0 arguments
            it = iter([])
        elif len(args.children) == 0:
            # function with 1 argument
            it = iter([args])
        else:
            # function with multiple arguments or 1 arg with default value
            it = iter(args.children)

        for ch in it:
            argstyle = 'name'
            if ch.type == token.STAR:
                # *arg part
                argstyle = 'star'
                ch = next(it)
                if ch.type == token.COMMA:
                    continue
            elif ch.type == token.DOUBLESTAR:
                # *arg part
                argstyle = 'keyword'
                ch = next(it)
            assert ch.type == token.NAME
            argleaves.append((argstyle, ch))
            try:
                ch = next(it)
                if ch.type == token.EQUAL:
                    ch = next(it)
                    ch = next(it)
                assert ch.type == token.COMMA
                continue
            except StopIteration:
                break

        # when self or cls is not annotated, argleaves == argtypes+1
        argleaves = argleaves[len(argleaves) - len(argtypes):]

        for ch_withstyle, chtype in zip(argleaves, argtypes):
            style, ch = ch_withstyle
            if style == 'star':
                assert chtype[0] == '*'
                assert chtype[1] != '*'
                chtype = chtype[1:]
            elif style == 'keyword':
                assert chtype[0:2] == '**'
                assert chtype[2] != '*'
                chtype = chtype[2:]
            ch.value = '%s: %s' % (ch.value, chtype)

            # put spaces around the equal sign
            if ch.next_sibling and ch.next_sibling.type == token.EQUAL:
                nextch = ch.next_sibling
                if not nextch.prefix[:1].isspace():
                    nextch.prefix = ' ' + nextch.prefix
                nextch = nextch.next_sibling
                assert nextch != None
                if not nextch.prefix[:1].isspace():
                    nextch.prefix = ' ' + nextch.prefix

        # Add return annotation
        rpar = results['rpar']
        rpar.value = '%s -> %s' % (rpar.value, restype)

        rpar.changed()

    def add_py2_annot(self, argtypes, restype, node, results):
        children = results['suite'][0].children

        # Insert '# type: {annot}' comment.
        # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib.
        if len(children) >= 1 and children[0].type != token.NEWLINE:
            # one liner function
            if children[0].prefix.strip() == '':
                children[0].prefix = ''
                children.insert(0, Leaf(token.NEWLINE, '\n'))
                children.insert(
                    1, Leaf(token.INDENT, find_indentation(node) + '    '))
                children.append(Leaf(token.DEDENT, ''))
        if len(children) >= 2 and children[1].type == token.INDENT:
            degen_str = '(...) -> %s' % restype
            short_str = '(%s) -> %s' % (', '.join(argtypes), restype)
            if (len(short_str) > 64 or len(argtypes) > 5) and len(short_str) > len(degen_str):
                self.insert_long_form(node, results, argtypes)
                annot_str = degen_str
            else:
                annot_str = short_str
            children[1].prefix = '%s# type: %s\n%s' % (children[1].value, annot_str,
                                                       children[1].prefix)
            children[1].changed()
        else:
            self.log_message("%s:%d: cannot insert annotation for one-line function" %
                             (self.filename, node.get_lineno()))

    def insert_long_form(self, node, results, argtypes):
        argtypes = list(argtypes)  # We destroy it
        args = results['args']
        if isinstance(args, Node):
            children = args.children
        elif isinstance(args, Leaf):
            children = [args]
        else:
            children = []
        # Interpret children according to the following grammar:
        # (('*'|'**')? NAME ['=' expr] ','?)*
        flag = False  # Set when the next leaf should get a type prefix
        indent = ''  # Will be set by the first child

        def set_prefix(child):
            if argtypes:
                arg = argtypes.pop(0).lstrip('*')
            else:
                arg = 'Any'  # Somehow there aren't enough args
            if not arg:
                # Skip self (look for 'check_self' below)
                prefix = child.prefix.rstrip()
            else:
                prefix = '  # type: ' + arg
                old_prefix = child.prefix.strip()
                if old_prefix:
                    assert old_prefix.startswith('#')
                    prefix += '  ' + old_prefix
            child.prefix = prefix + '\n' + indent

        check_self = self.is_method(node)
        for child in children:
            if check_self and isinstance(child, Leaf) and child.type == token.NAME:
                check_self = False
                if child.value in ('self', 'cls'):
                    argtypes.insert(0, '')
            if not indent:
                indent = ' ' * child.column
            if isinstance(child, Leaf) and child.value == ',':
                flag = True
            elif isinstance(child, Leaf) and flag:
                set_prefix(child)
                flag = False
        need_comma = len(children) >= 1 and children[-1].type != token.COMMA
        if need_comma and len(children) >= 2:
            if (children[-1].type == token.NAME and
                    (children[-2].type in (token.STAR, token.DOUBLESTAR))):
                need_comma = False
        if need_comma:
            children.append(Leaf(token.COMMA, u","))
        # Find the ')' and insert a prefix before it too.
        parameters = args.parent
        close_paren = parameters.children[-1]
        assert close_paren.type == token.RPAR, close_paren
        set_prefix(close_paren)
        assert not argtypes, argtypes

    def patch_imports(self, types, node):
        for typ in types:
            if 'Any' in typ:
                touch_import('typing', 'Any', node)
                break

    def make_annotation(self, node, results):
        name = results['name']
        assert isinstance(name, Leaf), repr(name)
        assert name.type == token.NAME, repr(name)
        decorators = self.get_decorators(node)
        is_method = self.is_method(node)
        if name.value == '__init__' or not self.has_return_exprs(node):
            restype = 'None'
        else:
            restype = 'Any'
        args = results.get('args')
        argtypes = []
        if isinstance(args, Node):
            children = args.children
        elif isinstance(args, Leaf):
            children = [args]
        else:
            children = []
        # Interpret children according to the following grammar:
        # (('*'|'**')? NAME ['=' expr] ','?)*
        stars = inferred_type = ''
        in_default = False
        at_start = True
        for child in children:
            if isinstance(child, Leaf):
                if child.value in ('*', '**'):
                    stars += child.value
                elif child.type == token.NAME and not in_default:
                    if not is_method or not at_start or 'staticmethod' in decorators:
                        inferred_type = 'Any'
                    else:
                        # Always skip the first argument if it's named 'self'.
                        # Always skip the first argument of a class method.
                        if child.value == 'self' or 'classmethod' in decorators:
                            pass
                        else:
                            inferred_type = 'Any'
                elif child.value == '=':
                    in_default = True
                elif in_default and child.value != ',':
                    if child.type == token.NUMBER:
                        if re.match(r'\d+[lL]?$', child.value):
                            inferred_type = 'int'
                        else:
                            inferred_type = 'float'  # TODO: complex?
                    elif child.type == token.STRING:
                        if child.value.startswith(('u', 'U')):
                            inferred_type = 'unicode'
                        else:
                            inferred_type = 'str'
                    elif child.type == token.NAME and child.value in ('True', 'False'):
                        inferred_type = 'bool'
                elif child.value == ',':
                    if inferred_type:
                        argtypes.append(stars + inferred_type)
                    # Reset
                    stars = inferred_type = ''
                    in_default = False
                    at_start = False
        if inferred_type:
            argtypes.append(stars + inferred_type)
        return argtypes, restype

    # The parse tree has a different shape when there is a single
    # decorator vs. when there are multiple decorators.
    DECORATED = "decorated< (d=decorator | decorators< dd=decorator+ >) funcdef >"
    decorated = compile_pattern(DECORATED)

    def get_decorators(self, node):
        """Return a list of decorators found on a function definition.

        This is a list of strings; only simple decorators
        (e.g. @staticmethod) are returned.

        If the function is undecorated or only non-simple decorators
        are found, return [].
        """
        if node.parent is None:
            return []
        results = {}
        if not self.decorated.match(node.parent, results):
            return []
        decorators = results.get('dd') or [results['d']]
        decs = []
        for d in decorators:
            for child in d.children:
                if isinstance(child, Leaf) and child.type == token.NAME:
                    decs.append(child.value)
        return decs

    def is_method(self, node):
        """Return whether the node occurs (directly) inside a class."""
        node = node.parent
        while node is not None:
            if node.type == syms.classdef:
                return True
            if node.type == syms.funcdef:
                return False
            node = node.parent
        return False

    RETURN_EXPR = "return_stmt< 'return' any >"
    return_expr = compile_pattern(RETURN_EXPR)

    def has_return_exprs(self, node):
        """Traverse the tree below node looking for 'return expr'.

        Return True if at least 'return expr' is found, False if not.
        (If both 'return' and 'return expr' are found, return True.)
        """
        results = {}
        if self.return_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if self.has_return_exprs(child):
                    return True
        return False

    YIELD_EXPR = "yield_expr< 'yield' [any] >"
    yield_expr = compile_pattern(YIELD_EXPR)

    def is_generator(self, node):
        """Traverse the tree below node looking for 'yield [expr]'."""
        results = {}
        if self.yield_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if self.is_generator(child):
                    return True
        return False