File: capture.py

package info (click to toggle)
python-beniget 0.4.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 188 kB
  • sloc: python: 1,590; sh: 7; makefile: 6
file content (45 lines) | stat: -rw-r--r-- 1,344 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from unittest import TestCase
from textwrap import dedent
import gast as ast
import beniget


class Capture(ast.NodeVisitor):
    """Collect identifiers that a code fragment loads without binding them.

    Def-use chains are computed once for the whole module. The visitor is
    then run on a sub-tree, typically a nested function. Every function-like
    scope entered during that visit adds the uses of its local definitions
    to ``users``. Any ``Load`` of a name that is not one of those uses is
    recorded in ``captured``.

    Only function-like scopes (``def``, ``async def``, ``lambda``) contribute
    local definitions. Loads that resolve to module-level or builtin names
    are therefore reported as captured as well.

    Attributes:
        chains: ``beniget.DefUseChains`` computed over the whole module.
        users: ``Name`` nodes known to refer to a local definition of a
            scope visited so far.
        captured: identifiers loaded without such a local definition.
    """

    def __init__(self, module_node):
        self.chains = beniget.DefUseChains()
        self.chains.visit(module_node)
        self.users = set()
        self.captured = set()

    def visit_FunctionDef(self, node):
        """Mark every use of *node*'s local definitions, then recurse.

        Users are registered before the body is visited, so loads inside
        the body (or in nested scopes) that refer to these locals are
        already known when ``visit_Name`` sees them.
        """
        # .get() avoids a KeyError for a scope with no recorded locals, and
        # does not insert an empty entry into the chains' mapping.
        for def_ in self.chains.locals.get(node, ()):
            self.users.update(use.node for use in def_.users())
        # NOTE(review): generic_visit also walks decorators and argument
        # defaults, which are evaluated in the enclosing scope; names there
        # are reported as captured too. Confirm this is intended.
        self.generic_visit(node)

    # ``async def`` and ``lambda`` open a function scope just like ``def``;
    # without these hooks their own parameters were reported as captured.
    # NOTE(review): comprehension scopes are not handled here; verify
    # whether beniget records their targets in ``locals`` before adding them.
    visit_AsyncFunctionDef = visit_FunctionDef
    visit_Lambda = visit_FunctionDef

    def visit_Name(self, node):
        """Record a loaded identifier that is not a use of a known local."""
        if isinstance(node.ctx, ast.Load) and node not in self.users:
            # FIXME: only the identifier is stored; ideally this would be
            # the definition this use resolves to.
            self.captured.add(node.id)


class TestCapture(TestCase):
    """Check which names ``Capture`` reports for a nested function."""

    @staticmethod
    def _inner_def(module):
        """Return the first statement inside the module's first definition."""
        outer = module.body[0]
        return outer.body[0]

    def checkCapture(self, code, extract, ref):
        """Parse *code*, visit the node picked by *extract*, compare to *ref*.

        *extract* receives the parsed module and returns the sub-tree to
        analyse; *ref* is the expected set of captured identifiers.
        """
        tree = ast.parse(dedent(code))
        visitor = Capture(tree)
        target = extract(tree)
        visitor.visit(target)
        self.assertEqual(visitor.captured, ref)

    def test_simple_capture(self):
        # ``bar`` reads ``x`` from ``foo``'s scope, so it is captured.
        code = """
            def foo(x):
                def bar():
                    return x"""
        self.checkCapture(code, extract=self._inner_def, ref={"x"})

    def test_no_capture(self):
        # ``bar`` shadows ``x`` with its own parameter: nothing is captured.
        code = """
            def foo(x):
                def bar(x):
                    return x"""
        self.checkCapture(code, extract=self._inner_def, ref=set())