File: functions.py

package info (click to toggle)
brian 2.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,872 kB
  • sloc: python: 51,820; cpp: 2,033; makefile: 108; sh: 72
file content (282 lines) | stat: -rw-r--r-- 8,894 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import ast
import inspect

from brian2.utils.stringtools import deindent, get_identifiers, indent

from .rendering import NodeRenderer

# Public API of this module. The AST helpers VarRewriter and FunctionRewriter,
# used by substitute_abstract_code_functions, are not listed here.
__all__ = [
    "AbstractCodeFunction",
    "abstract_code_from_function",
    "extract_abstract_code_functions",
    "substitute_abstract_code_functions",
]


class AbstractCodeFunction:
    """
    The information defining an abstract code function.

    Each initialisation parameter is stored as an attribute of the same name.

    Parameters
    ----------

    name : str
        The function name.
    args : list of str
        The argument names of the function.
    code : str
        The abstract code string consisting of the body of the function
        without the return statement.
    return_expr : str or None
        The expression returned, or None if there is nothing returned.
    """

    def __init__(self, name, args, code, return_expr):
        self.name = name
        self.args = args
        self.code = code
        self.return_expr = return_expr

    def __str__(self):
        # Render as Python-like source: the signature, the indented body and a
        # final ``return`` line (which reads ``return None`` when return_expr
        # is None).
        signature = f"def {self.name}({', '.join(self.args)}):"
        body = indent(self.code)
        footer = f"    return {self.return_expr}"
        return "\n".join([signature, body, footer]) + "\n"

    __repr__ = __str__


def _parse_function_def(source):
    """Parse ``source`` and return its first statement, which must be a ``def``."""
    body = ast.parse(source, mode="exec").body
    if not body or not isinstance(body[0], ast.FunctionDef):
        raise TypeError("Source code does not start with a function definition")
    return body[0]


def abstract_code_from_function(func):
    """
    Converts the body of the function to abstract code

    Parameters
    ----------
    func : function, str or ast.FunctionDef
        The function to convert. This can be a function object, whose source
        is retrieved with `inspect.getsource`. It can also be a string whose
        first statement is a ``def``, or an already parsed ``ast.FunctionDef``
        node. The function itself is never called.

    Returns
    -------
    func : AbstractCodeFunction
        The corresponding abstract code function. Statements after the first
        top-level ``return`` are ignored. ``return_expr`` is None if the body
        has no top-level ``return`` or only a bare ``return``.

    Raises
    ------
    TypeError
        If ``func`` is of an unsupported type, or its source does not start
        with a function definition.
    SyntaxError
        If unsupported features are used. In the signature these are
        ``*args``, ``**kwargs``, keyword-only arguments and default values.
        In the body they include if statements and indexing.
    """
    if callable(func):
        funcnode = _parse_function_def(deindent(inspect.getsource(func)))
    elif isinstance(func, str):
        funcnode = _parse_function_def(func)
    elif isinstance(func, ast.FunctionDef):
        funcnode = func
    else:
        raise TypeError("Unsupported function type")

    arguments = funcnode.args
    if arguments.vararg is not None:
        raise SyntaxError("No support for variable number of arguments")
    if arguments.kwarg is not None:
        raise SyntaxError("No support for arbitrary keyword arguments")
    # Keyword-only parameters cannot be filled by the positional inlining done
    # in FunctionRewriter, so reject them instead of silently dropping them.
    if arguments.kwonlyargs:
        raise SyntaxError("No support for keyword-only arguments")
    if arguments.defaults:
        raise SyntaxError("No support for default values in functions")

    nr = NodeRenderer()
    lines = []
    return_expr = None
    for node in funcnode.body:
        if isinstance(node, ast.Return):
            # A bare ``return`` returns nothing, just like having no return.
            if node.value is not None:
                return_expr = nr.render_node(node.value)
            break
        lines.append(nr.render_node(node))
    abstract_code = "\n".join(lines)
    # Positional-only parameters (before ``/``) are ordinary positional
    # arguments as far as inlining is concerned.
    args = [arg.arg for arg in arguments.posonlyargs + arguments.args]
    return AbstractCodeFunction(funcnode.name, args, abstract_code, return_expr)


def extract_abstract_code_functions(code):
    """
    Returns a dict of abstract code functions from function definitions.

    Returns all functions defined at the top level and ignores any other
    code in the string.

    Parameters
    ----------
    code : str
        The code string defining some functions. It is passed through
        ``deindent`` before being parsed.

    Returns
    -------
    funcs : dict
        A mapping ``name -> func`` for every top-level function definition,
        where ``func`` is an `AbstractCodeFunction`. If several definitions
        share a name, the last one wins.
    """
    module = ast.parse(deindent(code), mode="exec")
    funcs = {}
    for node in module.body:
        # Anything that is not a top-level ``def`` (assignments, imports,
        # classes, ...) is ignored.
        if isinstance(node, ast.FunctionDef):
            func = abstract_code_from_function(node)
            funcs[func.name] = func
    return funcs


class VarRewriter(ast.NodeTransformer):
    """
    Rewrites every variable name by prepending the prefix ``pre``.

    In calls, the function name itself is kept unchanged and only the
    positional arguments are rewritten. Keyword arguments of calls are
    dropped.
    """

    def __init__(self, pre):
        self.pre = pre

    def visit_Name(self, node):
        prefixed = self.pre + node.id
        return ast.Name(id=prefixed, ctx=node.ctx)

    def visit_Call(self, node):
        # Arguments are rewritten first; the callee name is left as is.
        rewritten_args = list(map(self.visit, node.args))
        callee = ast.Name(id=node.func.id, ctx=ast.Load())
        return ast.Call(func=callee, args=rewritten_args, keywords=[])


class FunctionRewriter(ast.NodeTransformer):
    """
    Inlines calls to a single abstract code function using temporary variables

    numcalls is the number of calls that this rewriter (and previous ones,
    see substitute_abstract_code_functions) has inlined so far. It is used
    to give the temporaries of each inlined call unique names, so that there
    is no name aliasing when recursively inlining. The
    substitute_abstract_code_functions function keeps it up to date between
    recursive runs.

    The pre attribute is the list of statements to be inserted above the
    statement currently being processed, i.e. the inlined code.

    Visiting a statement returns it with each inlined call replaced by the
    name of the variable holding that call's result.

    Calls to the function nested inside the arguments of another call to the
    same function (``f(f(x))``) are left for the next recursive pass. Calls
    nested inside calls to other functions (``exp(f(x))``) are inlined.
    """

    def __init__(self, func, numcalls=0):
        self.func = func
        self.numcalls = numcalls
        self.pre = []
        self.suspend = False

    def visit_Call(self, node):
        # We suspend operations while processing the arguments of an inlined
        # call, so we only ever expand one level of f(f(x)) at a time; the
        # remaining levels are handled by the recursion.
        if self.suspend:
            return node
        if not isinstance(node.func, ast.Name) or node.func.id != self.func.name:
            # Not a call to our function (possibly an attribute call such as
            # ``obj.m(x)``), but its arguments may still contain calls to it,
            # e.g. ``exp(f(x))``, so keep descending.
            return self.generic_visit(node)
        # Validate before touching self.pre so that no partial state is left.
        nargs, nexpected = len(node.args), len(self.func.args)
        if nargs != nexpected:
            raise TypeError(
                f"Function '{self.func.name}' takes {nexpected} argument(s), "
                f"but {nargs} were given"
            )
        if self.func.return_expr is None:
            raise TypeError(
                f"Function '{self.func.name}' does not return a value and "
                "cannot be inlined"
            )
        # Suspend while processing arguments (no recursion)
        self.suspend = True
        try:
            args = [self.visit(arg) for arg in node.args]
        finally:
            self.suspend = False
        # The basename is used for function-local variables
        basename = f"_inline_{self.func.name}_{self.numcalls}"
        # Assign all the function-local argument variables
        for argname, arg in zip(self.func.args, args):
            self.pre.append(
                ast.Assign(
                    targets=[ast.Name(id=f"{basename}_{argname}", ctx=ast.Store())],
                    value=arg,
                )
            )
        # Rewrite the lines of code of the function using the names defined
        # above
        vr = VarRewriter(f"{basename}_")
        for funcline in ast.parse(self.func.code).body:
            self.pre.append(vr.visit(funcline))
        # And rewrite the return expression
        return_expr = vr.visit(ast.parse(self.func.return_expr, mode="eval").body)
        self.pre.append(
            ast.Assign(
                targets=[ast.Name(id=basename, ctx=ast.Store())], value=return_expr
            )
        )
        self.numcalls += 1
        # Finally we replace the function call with the output of the inlining
        return ast.Name(id=basename, ctx=ast.Load())


def _first_free_inline_index(identifiers, funcname):
    """
    Return the first unused inlining number for the function ``funcname``.

    Inlining a call to ``funcname`` creates identifiers of the form
    ``_inline_<funcname>_<n>`` and ``_inline_<funcname>_<n>_<var>``. This
    returns one more than the largest ``n`` found in ``identifiers``, or 0
    if there is none.
    """
    prefix = f"_inline_{funcname}_"
    used = set()
    for identifier in identifiers:
        if not identifier.startswith(prefix):
            continue
        number = identifier[len(prefix) :].split("_", 1)[0]
        # Some identifiers carry no number here: those of another function
        # whose name merely starts with ``<funcname>_`` (e.g. ``f_g`` when
        # looking at ``f``), or unrelated names sharing the prefix. Skip them
        # instead of failing in ``int``.
        if number.isdecimal():
            used.add(int(number))
    return max(used) + 1 if used else 0


def substitute_abstract_code_functions(code, funcs):
    """
    Performs inline substitution of all the functions in the code

    Parameters
    ----------
    code : str
        The abstract code to make inline substitutions into.
    funcs : dict, or iterable (list, set, tuple, ...) of AbstractCodeFunction
        The function substitutions to use. For a mapping, the keys are
        ignored and each function's own name is used.

    Returns
    -------
    code : str
        The code with inline substitutions performed.
    """
    from collections.abc import Mapping

    if not isinstance(funcs, Mapping):
        funcs = {f.name: f for f in funcs}

    code = deindent(code)
    lines = ast.parse(code, mode="exec").body

    # Temporaries are numbered per function. Continue after the highest
    # number already present in the code (e.g. from an earlier pass of the
    # recursion below), so that new temporaries never alias existing ones.
    identifiers = get_identifiers(code)
    funcstarts = {
        func.name: _first_free_inline_index(identifiers, func.name)
        for func in funcs.values()
    }

    # Replace each statement by the inlined code followed by the statement
    # itself, with the inlined calls replaced by their result variables.
    newlines = []
    for line in lines:
        for func in funcs.values():
            rewriter = FunctionRewriter(func, funcstarts[func.name])
            line = rewriter.visit(line)
            newlines.extend(rewriter.pre)
            funcstarts[func.name] = rewriter.numcalls
        newlines.append(line)

    # Now we render to a code string
    nr = NodeRenderer()
    newcode = "\n".join(nr.render_node(line) for line in newlines)

    # Recurse until the code no longer changes, so that calls introduced by
    # inlining (nested calls, functions calling other functions) are expanded
    # too. A function that (indirectly) calls itself never reaches this fixed
    # point.
    if newcode == code:
        return newcode
    return substitute_abstract_code_functions(newcode, funcs)