1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
|
import ast
import inspect
import re
from collections.abc import Mapping

from brian2.utils.stringtools import deindent, get_identifiers, indent

from .rendering import NodeRenderer
# Names exported by ``from ... import *``: the function container class and
# the three public conversion / inlining helpers.
__all__ = [
    "AbstractCodeFunction",
    "abstract_code_from_function",
    "extract_abstract_code_functions",
    "substitute_abstract_code_functions",
]
class AbstractCodeFunction:
    """
    The information defining an abstract code function

    Has attributes corresponding to initialisation parameters.

    Parameters
    ----------
    name : str
        The function name.
    args : list of str
        The arguments to the function.
    code : str
        The abstract code string consisting of the body of the function less
        the return statement.
    return_expr : str or None
        The expression returned, or None if there is nothing returned.
    """

    def __init__(self, name, args, code, return_expr):
        self.name = name
        self.args = args
        self.code = code
        self.return_expr = return_expr

    def __str__(self):
        # The body is indented with `indent` (assumed to use its default
        # four-space indentation -- confirm against brian2.utils.stringtools);
        # the ``return`` line must use the same indentation, otherwise the
        # rendered definition is not valid Python.
        return (
            f"def {self.name}({', '.join(self.args)}):\n"
            f"{indent(self.code)}\n"
            f"    return {self.return_expr}\n"
        )

    __repr__ = __str__
def abstract_code_from_function(func):
    """
    Converts the body of the function to abstract code

    Parameters
    ----------
    func : function, str or ast.FunctionDef
        The function to convert. For a callable, its source code is obtained
        with `inspect.getsource`; a string must start with a function
        definition. In both cases common leading indentation is removed
        before parsing. The argument names (positional-only and regular
        positional) become the ``args`` of the result.

    Returns
    -------
    func : AbstractCodeFunction
        The corresponding abstract code function. Its ``code`` holds the
        rendered statements preceding the first ``return`` (any statements
        after it are ignored); its ``return_expr`` is the rendered returned
        expression, or None if there is no ``return`` with a value.

    Raises
    ------
    TypeError
        If ``func`` has an unsupported type, or its source does not start
        with a function definition.
    SyntaxError
        If the signature uses unsupported features: a variable number of
        arguments, arbitrary keyword arguments, keyword-only arguments or
        default values. NOTE(review): unsupported body constructs (e.g. if
        statements or indexing) are presumably rejected by `NodeRenderer`.
    """
    if callable(func):
        body = ast.parse(deindent(inspect.getsource(func)), mode="exec").body
    elif isinstance(func, str):
        # Deindent like the callable path so that indented source strings
        # (e.g. from triple-quoted literals) can be parsed.
        body = ast.parse(deindent(func), mode="exec").body
    elif isinstance(func, ast.FunctionDef):
        body = [func]
    else:
        raise TypeError("Unsupported function type")
    if not body or not isinstance(body[0], ast.FunctionDef):
        raise TypeError("Expected source code starting with a function definition")
    funcnode = body[0]

    signature = funcnode.args
    if signature.vararg is not None:
        raise SyntaxError("No support for variable number of arguments")
    if signature.kwarg is not None:
        raise SyntaxError("No support for arbitrary keyword arguments")
    if signature.kwonlyargs:
        # These would otherwise be silently dropped from ``args``.
        raise SyntaxError("No support for keyword-only arguments")
    if signature.defaults:
        raise SyntaxError("No support for default values in functions")

    renderer = NodeRenderer()
    lines = []
    return_expr = None
    for node in funcnode.body:
        if isinstance(node, ast.Return):
            # A bare ``return`` has no value; keep return_expr as None rather
            # than trying to render a missing node.
            if node.value is not None:
                return_expr = renderer.render_node(node.value)
            break
        lines.append(renderer.render_node(node))

    # Positional-only arguments (``def f(x, /)``) live in a separate list on
    # Python >= 3.8; they are ordinary arguments for inlining purposes.
    arg_nodes = list(getattr(signature, "posonlyargs", [])) + list(signature.args)
    args = [arg.arg for arg in arg_nodes]
    return AbstractCodeFunction(funcnode.name, args, "\n".join(lines), return_expr)
def extract_abstract_code_functions(code):
    """
    Collect the abstract code functions defined at the top level of a string.

    Only top-level ``def`` statements are converted; every other statement in
    the string is ignored.

    Parameters
    ----------
    code : str
        Source code containing function definitions. Common leading
        indentation is removed before parsing.

    Returns
    -------
    funcs : dict
        Mapping from function name to the corresponding
        `AbstractCodeFunction`. If a name is defined more than once, the last
        definition wins.
    """
    tree = ast.parse(deindent(code), mode="exec")
    converted = (
        abstract_code_from_function(stmt)
        for stmt in tree.body
        if type(stmt) is ast.FunctionDef
    )
    return {acf.name: acf for acf in converted}
class VarRewriter(ast.NodeTransformer):
    """
    Rewrites all variable names by prepending a prefix.

    Every `ast.Name` node that is visited (loads and stores alike) gets
    ``pre`` prepended to its identifier. The name of a called function is
    left untouched, but the call's positional and keyword arguments are
    rewritten.

    Parameters
    ----------
    pre : str
        The prefix prepended to every variable name.
    """

    def __init__(self, pre):
        self.pre = pre

    def visit_Name(self, node):
        return ast.Name(id=self.pre + node.id, ctx=node.ctx)

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            # Function names are not variables: keep them unprefixed.
            func = ast.Name(id=node.func.id, ctx=ast.Load())
        else:
            # e.g. ``np.sin(x)``: leave the callee expression as it is
            # instead of failing on a missing ``.id`` attribute.
            func = node.func
        args = [self.visit(arg) for arg in node.args]
        # Keep keyword arguments (with rewritten values) instead of silently
        # dropping them, so the call is not changed behind the user's back.
        keywords = [
            ast.keyword(arg=kw.arg, value=self.visit(kw.value))
            for kw in node.keywords
        ]
        return ast.Call(func=func, args=args, keywords=keywords)
class FunctionRewriter(ast.NodeTransformer):
    """
    Inlines calls to one abstract code function using temporary variables.

    Parameters
    ----------
    func : AbstractCodeFunction
        The function whose calls are inlined; its ``name``, ``args``,
        ``code`` and ``return_expr`` attributes are used.
    numcalls : int, optional
        The number of inlining operations already performed for ``func``.
        It is embedded in every generated variable name so that, when
        inlining recursively, there is no name aliasing.
        `substitute_abstract_code_functions` keeps it up to date between
        recursive runs.

    Notes
    -----
    The ``pre`` attribute collects the statements to be inserted above the
    line currently being processed, i.e. the inlined code. Visiting a line
    returns it with each inlined call replaced by the name of the variable
    holding the inlined return value.

    NOTE(review): `VarRewriter` prefixes every name in the function body and
    return expression, so the body is presumably expected to reference only
    its arguments and its own local variables -- confirm.
    """

    def __init__(self, func, numcalls=0):
        self.func = func
        self.numcalls = numcalls
        self.pre = []
        self.suspend = False

    def visit_Call(self, node):
        # Operations are suspended while a call's arguments are processed
        # (see below), so only one inline call is expanded at a time, i.e. no
        # f(f(x)) in a single pass. The inner call is handled by the
        # recursion in substitute_abstract_code_functions.
        if self.suspend:
            return node
        is_target = (
            isinstance(node.func, ast.Name) and node.func.id == self.func.name
        )
        if not is_target:
            # Keep this call, but still descend into its arguments so that
            # e.g. the ``f(x)`` inside ``exp(f(x))`` is inlined too.
            return self.generic_visit(node)
        if node.keywords or len(node.args) != len(self.func.args):
            # zip() below would silently truncate and produce broken code.
            raise TypeError(
                f"Inline function '{self.func.name}' expects "
                f"{len(self.func.args)} positional argument(s), got "
                f"{len(node.args)} positional and {len(node.keywords)} "
                "keyword argument(s)"
            )
        # Suspend while processing arguments (no recursion); always resume,
        # even if visiting an argument raises.
        self.suspend = True
        try:
            args = [self.visit(arg) for arg in node.args]
        finally:
            self.suspend = False
        # The basename is used for function-local variables
        basename = f"_inline_{self.func.name}_{self.numcalls}"
        # Assign all the function-local argument variables
        for argname, arg in zip(self.func.args, args):
            self.pre.append(
                ast.Assign(
                    targets=[ast.Name(id=f"{basename}_{argname}", ctx=ast.Store())],
                    value=arg,
                )
            )
        # Rewrite the function body using the names defined above
        vr = VarRewriter(f"{basename}_")
        for funcline in ast.parse(self.func.code).body:
            self.pre.append(vr.visit(funcline))
        # ...and store the rewritten return expression in the base variable
        return_expr = vr.visit(ast.parse(self.func.return_expr, mode="eval").body)
        self.pre.append(
            ast.Assign(
                targets=[ast.Name(id=basename, ctx=ast.Store())], value=return_expr
            )
        )
        self.numcalls += 1
        # Finally replace the call with the output of the inlining
        return ast.Name(id=basename, ctx=ast.Load())
def _next_inline_index(identifiers, funcname):
    """Return the first unused inlining counter for ``funcname``.

    Inlined code uses names ``_inline_<funcname>_<N>`` and
    ``_inline_<funcname>_<N>_<var>``. This returns one more than the largest
    ``N`` found among ``identifiers`` (0 if there is none). Identifiers whose
    suffix does not start with a number are ignored rather than causing a
    ``ValueError``. This happens e.g. for a function named
    ``<funcname>_inner``.
    """
    pattern = re.compile(rf"_inline_{re.escape(funcname)}_(\d+)(?:_|$)")
    indices = [int(m.group(1)) for m in map(pattern.match, identifiers) if m]
    return max(indices) + 1 if indices else 0


def substitute_abstract_code_functions(code, funcs):
    """
    Performs inline substitution of all the functions in the code

    Parameters
    ----------
    code : str
        The abstract code to make inline substitutions into.
    funcs : list, set, tuple, dict or other iterable of AbstractCodeFunction
        The function substitutions to use. For a dict (or other mapping) the
        keys are ignored and the function name is used.

    Returns
    -------
    code : str
        The code with inline substitutions performed.
    """
    if not isinstance(funcs, Mapping):
        funcs = {f.name: f for f in funcs}
    code = deindent(code)
    lines = ast.parse(code, mode="exec").body
    # Inlined variables carry a per-function counter in their names. Previous
    # (recursive) calls may already have used some counter values, so look at
    # the existing identifiers and continue after the largest one.
    ids = get_identifiers(code)
    funcstarts = {
        func.name: _next_inline_index(ids, func.name) for func in funcs.values()
    }
    # Replace each line by the inlined code followed by the rewritten line
    newlines = []
    for line in lines:
        for func in funcs.values():
            rewriter = FunctionRewriter(func, funcstarts[func.name])
            line = rewriter.visit(line)
            newlines.extend(rewriter.pre)
            funcstarts[func.name] = rewriter.numcalls
        newlines.append(line)
    renderer = NodeRenderer()
    newcode = "\n".join(renderer.render_node(line) for line in newlines)
    # Recurse until the code no longer changes, so that calls introduced by
    # inlining (one function calling another, f(f(x)), ...) are expanded.
    if newcode == code:
        return newcode
    return substitute_abstract_code_functions(newcode, funcs)
|