1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
|
from unittest import TestCase
from textwrap import dedent
import gast as ast
import beniget
class Capture(ast.NodeVisitor):
def __init__(self, module_node):
self.chains = beniget.DefUseChains()
self.chains.visit(module_node)
self.users = set()
self.captured = set()
def visit_FunctionDef(self, node):
for def_ in self.chains.locals[node]:
self.users.update(use.node for use in def_.users())
self.generic_visit(node)
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load):
if node not in self.users:
# FIXME: IRL, should be the definition of this use
self.captured.add(node.id)
class TestCapture(TestCase):
def checkCapture(self, code, extract, ref):
module = ast.parse(dedent(code))
c = Capture(module)
c.visit(extract(module))
self.assertEqual(c.captured, ref)
def test_simple_capture(self):
code = """
def foo(x):
def bar():
return x"""
self.checkCapture(code, lambda n: n.body[0].body[0], {"x"})
def test_no_capture(self):
code = """
def foo(x):
def bar(x):
return x"""
self.checkCapture(code, lambda n: n.body[0].body[0], set())
|